diff --git a/.changes/unreleased/Under the Hood-20230130-180917.yaml b/.changes/unreleased/Under the Hood-20230130-180917.yaml new file mode 100644 index 00000000000..64c35d67f12 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230130-180917.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: warn_error/warn_error_options mutual exclusivity in click +time: 2023-01-30T18:09:17.240662-05:00 +custom: + Author: michelleark + Issue: "6579" diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index a8af62bf61d..dcfb59507c5 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -5,9 +5,9 @@ from importlib import import_module from multiprocessing import get_context from pprint import pformat as pf -from typing import Set +from typing import Set, List -from click import Context, get_current_context +from click import Context, get_current_context, BadOptionUsage from click.core import ParameterSource from dbt.config.profile import read_user_config @@ -59,12 +59,15 @@ def assign_params(ctx, params_assigned_from_default): # Overwrite default assignments with user config if available if user_config: + param_assigned_from_default_copy = params_assigned_from_default.copy() for param_assigned_from_default in params_assigned_from_default: user_config_param_value = getattr(user_config, param_assigned_from_default, None) if user_config_param_value is not None: object.__setattr__( self, param_assigned_from_default.upper(), user_config_param_value ) + param_assigned_from_default_copy.remove(param_assigned_from_default) + params_assigned_from_default = param_assigned_from_default_copy # Hard coded flags object.__setattr__(self, "WHICH", invoked_subcommand_name or ctx.info_name) @@ -78,6 +81,10 @@ def assign_params(ctx, params_assigned_from_default): if os.getenv("DO_NOT_TRACK", "").lower() in ("1", "t", "true", "y", "yes") else True, ) + # Check mutual exclusivity once all flags are set + self._assert_mutually_exclusive( + params_assigned_from_default, ["WARN_ERROR", "WARN_ERROR_OPTIONS"] + ) # Support lower cased access for legacy code params = set( @@ -88,3 +95,20 @@ def assign_params(ctx, params_assigned_from_default): def __str__(self) -> str: return str(pf(self.__dict__)) + + def _assert_mutually_exclusive( + self, params_assigned_from_default: Set[str], group: List[str] + ) -> None: + """ + Ensure no elements from group are simultaneously provided by a user, as inferred from params_assigned_from_default. + Raises click.UsageError if any two elements from group are simultaneously provided by a user. + """ + set_flag = None + for flag in group: + flag_set_by_user = flag.lower() not in params_assigned_from_default + if flag_set_by_user and set_flag: + raise BadOptionUsage( + flag.lower(), f"{flag.lower()}: not allowed with argument {set_flag.lower()}" + ) + elif flag_set_by_user: + set_flag = flag diff --git a/core/dbt/cli/params.py b/core/dbt/cli/params.py index c78d5b5480b..34037a1b57e 100644 --- a/core/dbt/cli/params.py +++ b/core/dbt/cli/params.py @@ -392,7 +392,7 @@ envvar="DBT_WARN_ERROR", help="If dbt would normally warn, instead raise an exception. Examples include --select that selects nothing, deprecations, configurations with no associated models, invalid test configurations, and missing sources/refs in tests.", default=None, - flag_value=True, + is_flag=True, ) warn_error_options = click.option( diff --git a/core/dbt/events/functions.py b/core/dbt/events/functions.py index 00407b538bd..e06364b390f 100644 --- a/core/dbt/events/functions.py +++ b/core/dbt/events/functions.py @@ -168,11 +168,7 @@ def msg_to_dict(msg: EventMsg) -> dict: def warn_or_error(event, node=None): - # TODO: resolve this circular import when flags.WARN_ERROR_OPTIONS is WarnErrorOptions type via click CLI. - from dbt.helper_types import WarnErrorOptions - - warn_error_options = WarnErrorOptions.from_yaml_string(flags.WARN_ERROR_OPTIONS) - if flags.WARN_ERROR or warn_error_options.includes(type(event).__name__): + if flags.WARN_ERROR or flags.WARN_ERROR_OPTIONS.includes(type(event).__name__): # TODO: resolve this circular import when at top from dbt.exceptions import EventCompilationError diff --git a/core/dbt/flags.py b/core/dbt/flags.py index e5b94c7415b..f3ddbeb49df 100644 --- a/core/dbt/flags.py +++ b/core/dbt/flags.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Optional +from dbt.helper_types import WarnErrorOptions + # PROFILES_DIR must be set before the other flags # It also gets set in main.py and in set_from_args because the rpc server # doesn't go through exactly the same main arg processing. @@ -46,7 +48,7 @@ USE_EXPERIMENTAL_PARSER = None VERSION_CHECK = None WARN_ERROR = None -WARN_ERROR_OPTIONS = None +WARN_ERROR_OPTIONS = WarnErrorOptions(include=[]) WHICH = None WRITE_JSON = None @@ -170,7 +172,13 @@ def set_from_args(args, user_config): USE_EXPERIMENTAL_PARSER = get_flag_value("USE_EXPERIMENTAL_PARSER", args, user_config) VERSION_CHECK = get_flag_value("VERSION_CHECK", args, user_config) WARN_ERROR = get_flag_value("WARN_ERROR", args, user_config) - WARN_ERROR_OPTIONS = get_flag_value("WARN_ERROR_OPTIONS", args, user_config) + + warn_error_options_str = get_flag_value("WARN_ERROR_OPTIONS", args, user_config) + from dbt.cli.option_types import WarnErrorOptionsType + + # Converting to WarnErrorOptions for consistency with dbt/cli/flags.py + WARN_ERROR_OPTIONS = WarnErrorOptionsType().convert(warn_error_options_str, None, None) + WRITE_JSON = get_flag_value("WRITE_JSON", args, user_config) _check_mutually_exclusive(["WARN_ERROR", "WARN_ERROR_OPTIONS"], args, user_config) diff --git a/core/dbt/helper_types.py b/core/dbt/helper_types.py index 84f253b00c6..77e25c68ce8 100644 --- a/core/dbt/helper_types.py +++ b/core/dbt/helper_types.py @@ -123,22 +123,6 @@ def _validate_items(self, items: List[str]): class WarnErrorOptions(IncludeExclude): - # TODO: this method can be removed once the click CLI is in use - @classmethod - def from_yaml_string(cls, warn_error_options_str: Optional[str]): - - # TODO: resolve circular import - from dbt.config.utils import parse_cli_yaml_string - - warn_error_options_str = ( - str(warn_error_options_str) if warn_error_options_str is not None else "{}" - ) - warn_error_options = parse_cli_yaml_string(warn_error_options_str, "warn-error-options") - return cls( - include=warn_error_options.get("include", []), - exclude=warn_error_options.get("exclude", []), - ) - def _validate_items(self, items: List[str]): valid_exception_names = set( [name for name, cls in dbt_event_types.__dict__.items() if isinstance(cls, type)] diff --git a/test/unit/test_flags.py b/test/unit/test_flags.py index 6f03ec22e92..36648e7b5c3 100644 --- a/test/unit/test_flags.py +++ b/test/unit/test_flags.py @@ -6,6 +6,7 @@ from dbt import flags from dbt.contracts.project import UserConfig from dbt.graph.selector_spec import IndirectSelection +from dbt.helper_types import WarnErrorOptions class TestFlags(TestCase): @@ -66,13 +67,13 @@ def test__flags(self): # warn_error_options self.user_config.warn_error_options = '{"include": "all"}' flags.set_from_args(self.args, self.user_config) - self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": "all"}') + self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all")) os.environ['DBT_WARN_ERROR_OPTIONS'] = '{"include": []}' flags.set_from_args(self.args, self.user_config) - self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": []}') + self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include=[])) setattr(self.args, 'warn_error_options', '{"include": "all"}') flags.set_from_args(self.args, self.user_config) - self.assertEqual(flags.WARN_ERROR_OPTIONS, '{"include": "all"}') + self.assertEqual(flags.WARN_ERROR_OPTIONS, WarnErrorOptions(include="all")) # cleanup os.environ.pop('DBT_WARN_ERROR_OPTIONS') delattr(self.args, 'warn_error_options') @@ -283,7 +284,7 @@ def test__flags(self): def test__flags_are_mutually_exclusive(self): # options from user config self.user_config.warn_error = False - self.user_config.warn_error_options = '{"include":"all}' + self.user_config.warn_error_options = '{"include":"all"}' with pytest.raises(ValueError): flags.set_from_args(self.args, self.user_config) #cleanup @@ -292,7 +293,7 @@ def test__flags_are_mutually_exclusive(self): # options from args setattr(self.args, 'warn_error', False) - setattr(self.args, 'warn_error_options', '{"include":"all}') + setattr(self.args, 'warn_error_options', '{"include":"all"}') with pytest.raises(ValueError): flags.set_from_args(self.args, self.user_config) # cleanup @@ -310,7 +311,7 @@ def test__flags_are_mutually_exclusive(self): # options from user config + args self.user_config.warn_error = False - setattr(self.args, 'warn_error_options', '{"include":"all}') + setattr(self.args, 'warn_error_options', '{"include":"all"}') with pytest.raises(ValueError): flags.set_from_args(self.args, self.user_config) # cleanup diff --git a/tests/unit/test_cli_flags.py b/tests/unit/test_cli_flags.py index fccf7e859df..10f83e6aee9 100644 --- a/tests/unit/test_cli_flags.py +++ b/tests/unit/test_cli_flags.py @@ -7,6 +7,7 @@ from dbt.cli.main import cli from dbt.contracts.project import UserConfig from dbt.cli.flags import Flags +from dbt.helper_types import WarnErrorOptions class TestFlags: @@ -18,6 +19,10 @@ def make_dbt_context(self, context_name: str, args: List[str]) -> click.Context: def run_context(self) -> click.Context: return self.make_dbt_context("run", ["run"]) + @pytest.fixture + def user_config(self) -> UserConfig: + return UserConfig() + def test_which(self, run_context): flags = Flags(run_context) assert flags.WHICH == "run" @@ -55,9 +60,7 @@ def test_anonymous_usage_state( flags = Flags(run_context) assert flags.ANONYMOUS_USAGE_STATS == expected_anonymous_usage_stats - def test_empty_user_config_uses_default(self, run_context): - user_config = UserConfig() - + def test_empty_user_config_uses_default(self, run_context, user_config): flags = Flags(run_context, user_config) assert flags.USE_COLORS == run_context.params["use_colors"] @@ -65,8 +68,8 @@ def test_none_user_config_uses_default(self, run_context): flags = Flags(run_context, None) assert flags.USE_COLORS == run_context.params["use_colors"] - def test_prefer_user_config_to_default(self, run_context): - user_config = UserConfig(use_colors=False) + def test_prefer_user_config_to_default(self, run_context, user_config): + user_config.use_colors = False # ensure default value is not the same as user config assert run_context.params["use_colors"] is not user_config.use_colors @@ -80,10 +83,81 @@ def test_prefer_param_value_to_user_config(self): flags = Flags(context, user_config) assert flags.USE_COLORS - def test_prefer_env_to_user_config(self, monkeypatch): - user_config = UserConfig(use_colors=False) + def test_prefer_env_to_user_config(self, monkeypatch, user_config): + user_config.use_colors = False monkeypatch.setenv("DBT_USE_COLORS", "True") context = self.make_dbt_context("run", ["run"]) flags = Flags(context, user_config) assert flags.USE_COLORS + + def test_mutually_exclusive_options_passed_separately(self): + """Assert options that are mutually exclusive can be passed separately without error""" + warn_error_context = self.make_dbt_context("run", ["--warn-error", "run"]) + + flags = Flags(warn_error_context) + assert flags.WARN_ERROR + + warn_error_options_context = self.make_dbt_context( + "run", ["--warn-error-options", '{"include": "all"}', "run"] + ) + flags = Flags(warn_error_options_context) + assert flags.WARN_ERROR_OPTIONS == WarnErrorOptions(include="all") + + def test_mutually_exclusive_options_from_cli(self): + context = self.make_dbt_context( + "run", ["--warn-error", "--warn-error-options", '{"include": "all"}', "run"] + ) + + with pytest.raises(click.BadOptionUsage): + Flags(context) + + @pytest.mark.parametrize("warn_error", [True, False]) + def test_mutually_exclusive_options_from_user_config(self, warn_error, user_config): + user_config.warn_error = warn_error + context = self.make_dbt_context( + "run", ["--warn-error-options", '{"include": "all"}', "run"] + ) + + with pytest.raises(click.BadOptionUsage): + Flags(context, user_config) + + @pytest.mark.parametrize("warn_error", ["True", "False"]) + def test_mutually_exclusive_options_from_envvar(self, warn_error, monkeypatch): + monkeypatch.setenv("DBT_WARN_ERROR", warn_error) + monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include":"all"}') + context = self.make_dbt_context("run", ["run"]) + + with pytest.raises(click.BadOptionUsage): + Flags(context) + + @pytest.mark.parametrize("warn_error", [True, False]) + def test_mutually_exclusive_options_from_cli_and_user_config(self, warn_error, user_config): + user_config.warn_error = warn_error + context = self.make_dbt_context( + "run", ["--warn-error-options", '{"include": "all"}', "run"] + ) + + with pytest.raises(click.BadOptionUsage): + Flags(context, user_config) + + @pytest.mark.parametrize("warn_error", ["True", "False"]) + def test_mutually_exclusive_options_from_cli_and_envvar(self, warn_error, monkeypatch): + monkeypatch.setenv("DBT_WARN_ERROR", warn_error) + context = self.make_dbt_context( + "run", ["--warn-error-options", '{"include": "all"}', "run"] + ) + + with pytest.raises(click.BadOptionUsage): + Flags(context) + + @pytest.mark.parametrize("warn_error", ["True", "False"]) + def test_mutually_exclusive_options_from_user_config_and_envvar( + self, user_config, warn_error, monkeypatch + ): + user_config.warn_error = warn_error + monkeypatch.setenv("DBT_WARN_ERROR_OPTIONS", '{"include": "all"}') + context = self.make_dbt_context("run", ["run"]) + + with pytest.raises(click.BadOptionUsage): + Flags(context, user_config) diff --git a/tests/unit/test_dbt_runner.py b/tests/unit/test_dbt_runner.py index 780a837c579..f1dd2708fff 100644 --- a/tests/unit/test_dbt_runner.py +++ b/tests/unit/test_dbt_runner.py @@ -18,6 +18,10 @@ def test_command_invalid_option(self, dbt: dbtRunner) -> None: with pytest.raises(dbtUsageException): dbt.invoke(["deps", "--invalid-option"]) + def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: + with pytest.raises(dbtUsageException): + dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) + def test_invalid_command(self, dbt: dbtRunner) -> None: with pytest.raises(dbtUsageException): dbt.invoke(["invalid-command"])