From 51642a37ce6a830c8361f7be83be9cfd555c9df5 Mon Sep 17 00:00:00 2001 From: Jeremy Cohen Date: Wed, 8 Nov 2023 17:41:41 -0500 Subject: [PATCH] Fix mutually exclusive flag detection --- core/dbt/cli/flags.py | 29 +++++++++++++++---- .../functional/dbt_runner/test_dbt_runner.py | 2 ++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/core/dbt/cli/flags.py b/core/dbt/cli/flags.py index 0055032e146..2d7f108dc9f 100644 --- a/core/dbt/cli/flags.py +++ b/core/dbt/cli/flags.py @@ -107,6 +107,7 @@ def _get_params_by_source(ctx: Context, source_type: ParameterSource): def _assign_params( ctx: Context, params_assigned_from_default: set, + params_assigned_from_user: set, deprecated_env_vars: Dict[str, Callable], ): """Recursively adds all click params to flag object""" @@ -173,15 +174,30 @@ def _assign_params( object.__setattr__(self, flag_name, param_value) # Track default assigned params. - if is_default: + # For flags that are accepted at both 'parent' and 'child' levels, + # we need to track user-provided and default values across both, + # to support detection of mutually exclusive flags later on. + if not is_default: + params_assigned_from_user.add(param_name) + if param_name in params_assigned_from_default: + params_assigned_from_default.remove(param_name) + if is_default and param_name not in params_assigned_from_user: params_assigned_from_default.add(param_name) if ctx.parent: - _assign_params(ctx.parent, params_assigned_from_default, deprecated_env_vars) + _assign_params( + ctx.parent, + params_assigned_from_default, + params_assigned_from_user, + deprecated_env_vars, + ) + params_assigned_from_user = set() # type: Set[str] params_assigned_from_default = set() # type: Set[str] deprecated_env_vars: Dict[str, Callable] = {} - _assign_params(ctx, params_assigned_from_default, deprecated_env_vars) + _assign_params( + ctx, params_assigned_from_default, params_assigned_from_user, deprecated_env_vars + ) # Set deprecated_env_var_warnings to be fired later after events have been init. object.__setattr__( @@ -198,7 +214,10 @@ def _assign_params( invoked_subcommand.ignore_unknown_options = True invoked_subcommand_ctx = invoked_subcommand.make_context(None, sys.argv) _assign_params( - invoked_subcommand_ctx, params_assigned_from_default, deprecated_env_vars + invoked_subcommand_ctx, + params_assigned_from_default, + params_assigned_from_user, + deprecated_env_vars, ) if not user_config: @@ -350,8 +369,6 @@ def add_fn(x): elif (v is None or v is False) and k not in ( # These are None by default but they do not support --no-{flag} "defer_state", - "warn_error", - "warn_error_options", "log_format", ): add_fn(f"--no-{spinal_cased}") diff --git a/tests/functional/dbt_runner/test_dbt_runner.py b/tests/functional/dbt_runner/test_dbt_runner.py index 20041f05952..40edcccae8d 100644 --- a/tests/functional/dbt_runner/test_dbt_runner.py +++ b/tests/functional/dbt_runner/test_dbt_runner.py @@ -23,6 +23,8 @@ def test_command_invalid_option(self, dbt: dbtRunner) -> None: def test_command_mutually_exclusive_option(self, dbt: dbtRunner) -> None: res = dbt.invoke(["--warn-error", "--warn-error-options", '{"include": "all"}', "deps"]) assert type(res.exception) == DbtUsageException + res = dbt.invoke(["deps", "--warn-error", "--warn-error-options", '{"include": "all"}']) + assert type(res.exception) == DbtUsageException def test_invalid_command(self, dbt: dbtRunner) -> None: res = dbt.invoke(["invalid-command"])