From 6cc3e15c1ff845f04fa41468f4cbb0e6b394c6dd Mon Sep 17 00:00:00 2001 From: Laura Couto Date: Tue, 21 Jan 2025 12:36:49 -0300 Subject: [PATCH] Switch yes/no CLI prompts to click validation Signed-off-by: Laura Couto --- kedro/framework/cli/starters.py | 37 ++++++++++++++++++---------- tests/framework/cli/test_starters.py | 8 ++++-- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/kedro/framework/cli/starters.py b/kedro/framework/cli/starters.py index 5b1e087a60..88c580be3b 100644 --- a/kedro/framework/cli/starters.py +++ b/kedro/framework/cli/starters.py @@ -196,6 +196,14 @@ def _validate_flag_inputs(flag_inputs: dict[str, Any]) -> None: ) +def _validate_yes_no(ctx, param, value) -> None: + if value is None: + return None + if not re.match(r"(?i)^\s*(y|yes|n|no)\s*$", value, flags=re.X): + raise click.BadParameter(f"'{value}' is an invalid value for {param}. It must contain only y, n, YES, or NO (case insensitive).") + return value + + def _validate_input_with_regex_pattern(pattern_name: str, input: str) -> None: VALIDATION_PATTERNS = { "yes_no": { @@ -324,8 +332,8 @@ def starter() -> None: "selected_tools", help=TOOLS_ARG_HELP, ) -@click.option("--example", "-e", "example_pipeline", help=EXAMPLE_ARG_HELP) -@click.option("--telemetry", "-tc", "telemetry_consent", help=TELEMETRY_ARG_HELP) +@click.option("--example", "-e", "example_pipeline", help=EXAMPLE_ARG_HELP, callback=_validate_yes_no) +@click.option("--telemetry", "-tc", "telemetry_consent", help=TELEMETRY_ARG_HELP, callback=_validate_yes_no) def new( # noqa: PLR0913 config_path: str, starter_alias: str, @@ -412,7 +420,6 @@ def new( # noqa: PLR0913 ) if telemetry_consent is not None: - _validate_input_with_regex_pattern("yes_no", telemetry_consent) telemetry_consent = ( "true" if _parse_yes_no_to_bool(telemetry_consent) else "false" ) @@ -520,7 +527,6 @@ def _get_prompts_required_and_clear_from_CLI_provided( del prompts_required["project_name"] if example_pipeline is not None: - _validate_input_with_regex_pattern("yes_no", example_pipeline) del prompts_required["example_pipeline"] return prompts_required @@ -728,7 +734,7 @@ def _fetch_validate_parse_config_from_file( ) example_pipeline = config.get("example_pipeline", "no") - _validate_input_with_regex_pattern("yes_no", example_pipeline) + _validate_yes_no(None, "example pipeline", example_pipeline) config["example_pipeline"] = str(_parse_yes_no_to_bool(example_pipeline)) tools_short_names = config.get("tools", "none").lower() @@ -765,16 +771,15 @@ def _fetch_validate_parse_config_from_user_prompts( default_value = cookiecutter_context.get(variable_name) or "" # read the user's input for the variable - user_input = click.prompt( + prompt.user_input = click.prompt( str(prompt), default=default_value, show_default=True, type=str, ).strip() - if user_input: - prompt.validate(user_input) - config[variable_name] = user_input + if prompt.user_input: + config[variable_name] = prompt.user_input if "tools" in config: # convert tools input to list of numbers and validate @@ -1024,13 +1029,19 @@ def __str__(self) -> str: prompt_text = "\n".join(str(line).strip() for line in prompt_lines) return f"\n{prompt_text}\n" - def validate(self, user_input: str) -> None: - """Validate a given prompt value against the regex validator""" - if self.regexp and not re.match(self.regexp, user_input): - message = f"'{user_input}' is an invalid value for {(self.title).lower()}." + @property + def user_input(self): + return self._user_input + + @user_input.setter + def user_input(self, input): + """Validate and set the user input.""" + if self.regexp and not re.match(self.regexp, input): + message = f"'{input}' is an invalid value for {(self.title).lower()}." click.secho(message, fg="red", err=True) click.secho(self.error_message, fg="red", err=True) sys.exit(1) + self._user_input = input def _remove_readonly( diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index a1852c75af..f6b65a6c0c 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -1676,10 +1676,14 @@ def test_flag_value_is_invalid(self, fake_kedro_cli): ) repo_name = "new-kedro-project" - assert result.exit_code == 1 + assert result.exit_code == 2 assert ( - "'wrong' is an invalid value for example pipeline. It must contain only y, n, YES, or NO (case insensitive)." + "'wrong' is an invalid value" + in result.output + ) + assert ( + " It must contain only y, n, YES, or NO (case insensitive)." in result.output )