Skip to content

Commit

Permalink
Switch yes/no CLI prompts to click validation
Browse files Browse the repository at this point in the history
Signed-off-by: Laura Couto <laurarccouto@gmail.com>
  • Loading branch information
lrcouto committed Jan 21, 2025
1 parent da38e1a commit 6cc3e15
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
37 changes: 24 additions & 13 deletions kedro/framework/cli/starters.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ def _validate_flag_inputs(flag_inputs: dict[str, Any]) -> None:
)


def _validate_yes_no(ctx, param, value) -> None:
if value is None:
return None
if not re.match(r"(?i)^\s*(y|yes|n|no)\s*$", value, flags=re.X):
raise click.BadParameter(f"'{value}' is an invalid value for {param}. It must contain only y, n, YES, or NO (case insensitive).")
return value


def _validate_input_with_regex_pattern(pattern_name: str, input: str) -> None:
VALIDATION_PATTERNS = {
"yes_no": {
Expand Down Expand Up @@ -324,8 +332,8 @@ def starter() -> None:
"selected_tools",
help=TOOLS_ARG_HELP,
)
@click.option("--example", "-e", "example_pipeline", help=EXAMPLE_ARG_HELP)
@click.option("--telemetry", "-tc", "telemetry_consent", help=TELEMETRY_ARG_HELP)
@click.option("--example", "-e", "example_pipeline", help=EXAMPLE_ARG_HELP, callback=_validate_yes_no)
@click.option("--telemetry", "-tc", "telemetry_consent", help=TELEMETRY_ARG_HELP, callback=_validate_yes_no)
def new( # noqa: PLR0913
config_path: str,
starter_alias: str,
Expand Down Expand Up @@ -412,7 +420,6 @@ def new( # noqa: PLR0913
)

if telemetry_consent is not None:
_validate_input_with_regex_pattern("yes_no", telemetry_consent)
telemetry_consent = (
"true" if _parse_yes_no_to_bool(telemetry_consent) else "false"
)
Expand Down Expand Up @@ -520,7 +527,6 @@ def _get_prompts_required_and_clear_from_CLI_provided(
del prompts_required["project_name"]

if example_pipeline is not None:
_validate_input_with_regex_pattern("yes_no", example_pipeline)
del prompts_required["example_pipeline"]

return prompts_required
Expand Down Expand Up @@ -728,7 +734,7 @@ def _fetch_validate_parse_config_from_file(
)

example_pipeline = config.get("example_pipeline", "no")
_validate_input_with_regex_pattern("yes_no", example_pipeline)
_validate_yes_no(None, "example pipeline", example_pipeline)
config["example_pipeline"] = str(_parse_yes_no_to_bool(example_pipeline))

tools_short_names = config.get("tools", "none").lower()
Expand Down Expand Up @@ -765,16 +771,15 @@ def _fetch_validate_parse_config_from_user_prompts(
default_value = cookiecutter_context.get(variable_name) or ""

# read the user's input for the variable
user_input = click.prompt(
prompt.user_input = click.prompt(
str(prompt),
default=default_value,
show_default=True,
type=str,
).strip()

if user_input:
prompt.validate(user_input)
config[variable_name] = user_input
if prompt.user_input:
config[variable_name] = prompt.user_input

if "tools" in config:
# convert tools input to list of numbers and validate
Expand Down Expand Up @@ -1024,13 +1029,19 @@ def __str__(self) -> str:
prompt_text = "\n".join(str(line).strip() for line in prompt_lines)
return f"\n{prompt_text}\n"

def validate(self, user_input: str) -> None:
"""Validate a given prompt value against the regex validator"""
if self.regexp and not re.match(self.regexp, user_input):
message = f"'{user_input}' is an invalid value for {(self.title).lower()}."
@property
def user_input(self):
return self._user_input

@user_input.setter
def user_input(self, input):
"""Validate and set the user input."""
if self.regexp and not re.match(self.regexp, input):
message = f"'{input}' is an invalid value for {(self.title).lower()}."
click.secho(message, fg="red", err=True)
click.secho(self.error_message, fg="red", err=True)
sys.exit(1)
self._user_input = input


def _remove_readonly(
Expand Down
8 changes: 6 additions & 2 deletions tests/framework/cli/test_starters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1676,10 +1676,14 @@ def test_flag_value_is_invalid(self, fake_kedro_cli):
)

repo_name = "new-kedro-project"
assert result.exit_code == 1
assert result.exit_code == 2

assert (
"'wrong' is an invalid value for example pipeline. It must contain only y, n, YES, or NO (case insensitive)."
"'wrong' is an invalid value"
in result.output
)
assert (
" It must contain only y, n, YES, or NO (case insensitive)."
in result.output
)

Expand Down

0 comments on commit 6cc3e15

Please sign in to comment.