diff --git a/openfeature/client.py b/openfeature/client.py index 1ccee33b..b27866ce 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -425,6 +425,7 @@ def _create_provider_evaluation( raise GeneralError(error_message="Unknown flag type") resolution = get_details_callable(*args) + resolution.raise_for_error() # we need to check the get_args to be compatible with union types. _typecheck_flag_value(resolution.value, flag_type) diff --git a/openfeature/exception.py b/openfeature/exception.py index e6ad2456..d17c28fb 100644 --- a/openfeature/exception.py +++ b/openfeature/exception.py @@ -1,18 +1,10 @@ +from __future__ import annotations + import typing +from collections.abc import Mapping from enum import Enum -class ErrorCode(Enum): - PROVIDER_NOT_READY = "PROVIDER_NOT_READY" - PROVIDER_FATAL = "PROVIDER_FATAL" - FLAG_NOT_FOUND = "FLAG_NOT_FOUND" - PARSE_ERROR = "PARSE_ERROR" - TYPE_MISMATCH = "TYPE_MISMATCH" - TARGETING_KEY_MISSING = "TARGETING_KEY_MISSING" - INVALID_CONTEXT = "INVALID_CONTEXT" - GENERAL = "GENERAL" - - class OpenFeatureError(Exception): """ A generic open feature exception, this exception should not be raised. Instead @@ -156,3 +148,32 @@ def __init__(self, error_message: typing.Optional[str]): raised """ super().__init__(ErrorCode.INVALID_CONTEXT, error_message) + + +class ErrorCode(Enum): + PROVIDER_NOT_READY = "PROVIDER_NOT_READY" + PROVIDER_FATAL = "PROVIDER_FATAL" + FLAG_NOT_FOUND = "FLAG_NOT_FOUND" + PARSE_ERROR = "PARSE_ERROR" + TYPE_MISMATCH = "TYPE_MISMATCH" + TARGETING_KEY_MISSING = "TARGETING_KEY_MISSING" + INVALID_CONTEXT = "INVALID_CONTEXT" + GENERAL = "GENERAL" + + __exceptions__: Mapping[str, typing.Callable[[str], OpenFeatureError]] = { + PROVIDER_NOT_READY: ProviderNotReadyError, + PROVIDER_FATAL: ProviderFatalError, + FLAG_NOT_FOUND: FlagNotFoundError, + PARSE_ERROR: ParseError, + TYPE_MISMATCH: TypeMismatchError, + TARGETING_KEY_MISSING: TargetingKeyMissingError, + INVALID_CONTEXT: InvalidContextError, + GENERAL: GeneralError, + } + + @classmethod + def to_exception( + cls, error_code: ErrorCode, error_message: str + ) -> OpenFeatureError: + exc = cls.__exceptions__.get(error_code.value, GeneralError) + return exc(error_message) diff --git a/openfeature/flag_evaluation.py b/openfeature/flag_evaluation.py index 98adab4b..86233ed2 100644 --- a/openfeature/flag_evaluation.py +++ b/openfeature/flag_evaluation.py @@ -63,3 +63,8 @@ class FlagResolutionDetails(typing.Generic[U_co]): reason: typing.Optional[typing.Union[str, Reason]] = None variant: typing.Optional[str] = None flag_metadata: FlagMetadata = field(default_factory=dict) + + def raise_for_error(self) -> None: + if self.error_code: + raise ErrorCode.to_exception(self.error_code, self.error_message or "") + return None diff --git a/tests/test_client.py b/tests/test_client.py index 5f710609..43223d99 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,12 +2,12 @@ import pytest -from openfeature.api import add_hooks, clear_hooks, set_provider +from openfeature.api import add_hooks, clear_hooks, get_client, set_provider from openfeature.client import OpenFeatureClient from openfeature.exception import ErrorCode, OpenFeatureError -from openfeature.flag_evaluation import Reason +from openfeature.flag_evaluation import FlagResolutionDetails, Reason from openfeature.hook import Hook -from openfeature.provider import ProviderStatus +from openfeature.provider import FeatureProvider, ProviderStatus from openfeature.provider.in_memory_provider import InMemoryFlag, InMemoryProvider from openfeature.provider.no_op_provider import NoOpProvider @@ -236,3 +236,27 @@ def test_should_shortcircuit_if_provider_is_in_irrecoverable_error_state( assert flag_details.reason == Reason.ERROR assert flag_details.error_code == ErrorCode.PROVIDER_FATAL spy_hook.error.assert_called_once() + + +def test_should_run_error_hooks_if_provider_returns_resolution_with_error_code(): + # Given + spy_hook = MagicMock(spec=Hook) + provider = MagicMock(spec=FeatureProvider) + provider.get_provider_hooks.return_value = [] + provider.resolve_boolean_details.return_value = FlagResolutionDetails( + value=True, + reason=Reason.ERROR, + error_code=ErrorCode.PROVIDER_FATAL, + error_message="This is an error message", + ) + set_provider(provider) + client = get_client() + client.add_hooks([spy_hook]) + # When + flag_details = client.get_boolean_details(flag_key="Key", default_value=True) + # Then + assert flag_details is not None + assert flag_details.value + assert flag_details.reason == Reason.ERROR + assert flag_details.error_code == ErrorCode.PROVIDER_FATAL + spy_hook.error.assert_called_once()