Skip to content

Commit

Permalink
Update Holder tests to use import, add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Morendil committed Apr 5, 2019
1 parent f61835e commit f000f22
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions openfisca_core/indexed_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ def encode(cls, array):
if array.dtype.kind in {'U', 'S'}: # String array
array = np.select([array == item.name for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
elif array.dtype.kind == 'O': # Enum items arrays
# Ensure we are comparing the comparable. The problem this fixes:
# On entering "cls" this method will generally come from variable.possible_values,
# while the array values may come from directly importing a module containing an Enum class.
# However, variables (and hence their possible_values) are loaded by a call to load_module,
# which gives them a different identity from the ones imported in the usual way.
# So, instead of relying on the "cls" passed in, we use only its name to check that
# the values in the array, if non-empty, are of the right type.
if len(array) > 0 and cls.__name__ is array[0].__class__.__name__:
cls = array[0].__class__
array = np.select([array == item for item in cls], [item.index for item in cls]).astype(ENUM_ARRAY_DTYPE)
return EnumArray(array, cls)

Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import openfisca_country_template.situation_examples
from openfisca_core.simulation_builder import SimulationBuilder
from openfisca_country_template.variables.housing import HousingOccupancyStatus
from openfisca_core.periods import period as make_period, ETERNITY
from openfisca_core.tools import assert_near
from openfisca_core.memory_config import MemoryConfig
Expand All @@ -26,7 +27,6 @@ def couple():


period = make_period('2017-12')
HousingOccupancyStatus = tax_benefit_system.get_variable('housing_occupancy_status').possible_values


def test_set_input_enum_string(couple):
Expand Down

0 comments on commit f000f22

Please sign in to comment.