Skip to content

Commit

Permalink
refactor(airflow): Remove bootstrap_project (#599)
Browse files Browse the repository at this point in the history
* Remove bootstrap_project

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Refactor

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Fix tests

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Print for debugging

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* add resolve

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update tests

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Signed-off-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com>
  • Loading branch information
ankatiyar authored Mar 21, 2024
1 parent 360457e commit 9acc3a5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 40 deletions.
20 changes: 11 additions & 9 deletions kedro-airflow/kedro_airflow/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from kedro.framework.context import KedroContext
from kedro.framework.project import pipelines
from kedro.framework.session import KedroSession
from kedro.framework.startup import ProjectMetadata, bootstrap_project
from kedro.framework.startup import ProjectMetadata
from slugify import slugify

from kedro_airflow.grouping import group_memory_nodes
Expand Down Expand Up @@ -100,8 +100,8 @@ def _get_pipeline_config(config_airflow: dict, params: dict, pipeline_name: str)
"-t",
"--target-dir",
"target_path",
type=click.Path(writable=True, resolve_path=True, file_okay=False),
default="./airflow_dags/",
type=click.Path(writable=True, resolve_path=False, file_okay=False),
default="airflow_dags/",
help="The directory path to store the generated Airflow dags",
)
@click.option(
Expand Down Expand Up @@ -152,10 +152,7 @@ def create( # noqa: PLR0913, PLR0912
raise click.BadParameter(
"The `--all` and `--pipeline` option are mutually exclusive."
)

project_path = Path.cwd().resolve()
bootstrap_project(project_path)
with KedroSession.create(project_path=project_path, env=env) as session:
with KedroSession.create(project_path=metadata.project_path, env=env) as session:
context = session.load_context()
config_airflow = _load_config(context)

Expand All @@ -165,10 +162,15 @@ def create( # noqa: PLR0913, PLR0912
jinja_env.filters["slugify"] = slugify
template = jinja_env.get_template(jinja_file.name)

dags_folder = Path(target_path)
dags_folder = (
Path(target_path)
if Path(target_path).is_absolute()
else metadata.project_path / Path(target_path)
)

# Ensure that the DAGs folder exists
dags_folder.mkdir(parents=True, exist_ok=True)
secho(f"Location of the Airflow DAG folder: {target_path!s}", fg="green")
secho(f"Location of the Airflow DAG folder: {dags_folder!s}", fg="green")

package_name = metadata.package_name

Expand Down
17 changes: 4 additions & 13 deletions kedro-airflow/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from shutil import copyfile

from click.testing import CliRunner
from kedro import __version__ as kedro_version
from kedro.framework.cli.starters import create_cli as kedro_cli
from kedro.framework.startup import ProjectMetadata
from kedro.framework.startup import bootstrap_project
from pytest import fixture


Expand Down Expand Up @@ -99,14 +98,6 @@ def register_pipelines():
@fixture(scope="session")
def metadata(kedro_project):
# cwd() depends on ^ the isolated filesystem, created by CliRunner()
project_path = kedro_project
return ProjectMetadata(
source_dir=project_path / "src",
config_file=project_path / "pyproject.toml",
package_name="hello_world",
project_name="Hello world !!!",
kedro_init_version=kedro_version,
project_path=project_path,
tools=["None"],
example_pipeline="No",
)
project_path = kedro_project.resolve()
metadata = bootstrap_project(project_path)
return metadata
35 changes: 17 additions & 18 deletions kedro-airflow/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
"dag_name,pipeline_name,command",
[
# Test normal execution
("hello_world", "__default__", ["airflow", "create"]),
("fake_project", "__default__", ["airflow", "create"]),
# Test execution with alternate pipeline name
("hello_world", "ds", ["airflow", "create", "--pipeline", "ds"]),
# Test with grouping
("hello_world", "__default__", ["airflow", "create", "--group-in-memory"]),
("fake_project", "ds", ["airflow", "create", "--pipeline", "ds"]),
# # Test with grouping
("fake_project", "__default__", ["airflow", "create", "--group-in-memory"]),
],
)
def test_create_airflow_dag(dag_name, pipeline_name, command, cli_runner, metadata):
"""Check the generation and validity of a simple Airflow DAG."""
dag_file = (
Path.cwd()
metadata.project_path
/ "airflow_dags"
/ (
f"{dag_name}_dag.py"
Expand Down Expand Up @@ -52,7 +52,7 @@ def _create_kedro_airflow_yml(file_name: Path, content: dict[str, Any]):

def test_airflow_config_params(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"

Expand All @@ -72,7 +72,7 @@ def test_airflow_config_params(cli_runner, metadata):

def test_airflow_config_params_cli(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"

Expand All @@ -92,7 +92,7 @@ def test_airflow_config_params_cli(cli_runner, metadata):

def test_airflow_config_params_from_config(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"

Expand Down Expand Up @@ -128,7 +128,7 @@ def test_airflow_config_params_from_config(cli_runner, metadata):

def test_airflow_config_params_from_config_non_default(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"
default_content = "hello"
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_airflow_config_params_from_config_non_default(cli_runner, metadata):

def test_airflow_config_params_env(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"

Expand All @@ -185,7 +185,7 @@ def test_airflow_config_params_env(cli_runner, metadata):

def test_airflow_config_params_custom_pipeline(cli_runner, metadata):
"""Check if config variables are picked up"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "airflow_params.j2"
content = "{{ owner | default('hello')}}"

Expand Down Expand Up @@ -213,7 +213,7 @@ def _create_kedro_airflow_jinja_template(path: Path, name: str, content: str):

def test_custom_template_exists(cli_runner, metadata):
"""Test execution with different dir and filename for Jinja2 Template"""
dag_name = "hello_world"
dag_name = "fake_project"
template_name = "custom_template.j2"
command = ["airflow", "create", "-j", template_name]
content = "print('my custom dag')"
Expand Down Expand Up @@ -252,7 +252,7 @@ def _kedro_remove_env(project_root: Path, env: str):

def test_create_airflow_dag_env_parameter_exists(cli_runner, metadata):
"""Test the `env` parameter"""
dag_name = "hello_world"
dag_name = "fake_project"
command = ["airflow", "create", "--env", "remote"]

_kedro_create_env(Path.cwd(), "remote")
Expand Down Expand Up @@ -292,7 +292,7 @@ def test_create_airflow_dag_tags_parameter_exists(
tags, expected_airflow_dags, unexpected_airflow_dags, cli_runner, metadata
):
"""Test the `tags` parameter"""
dag_name = "hello_world"
dag_name = "fake_project"
command = ["airflow", "create", "--env", "remote"] + tags

_kedro_create_env(Path.cwd(), "remote")
Expand Down Expand Up @@ -329,14 +329,13 @@ def test_create_airflow_all_dags(cli_runner, metadata):
result = cli_runner.invoke(commands, command, obj=metadata)

assert result.exit_code == 0, (result.exit_code, result.stdout)
print(result.stdout)

for dag_name, pipeline_name in [
("hello_world", "__default__"),
("hello_world", "ds"),
("fake_project", "__default__"),
("fake_project", "ds"),
]:
dag_file = (
Path.cwd()
metadata.project_path
/ "airflow_dags"
/ (
f"{dag_name}_dag.py"
Expand Down

0 comments on commit 9acc3a5

Please sign in to comment.