Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Feb 3, 2025
1 parent 9e017cd commit 9e77cbd
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 84 deletions.
76 changes: 74 additions & 2 deletions src/zenml/config/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Class for compiling ZenML pipelines into a serializable format."""

import copy
import os
import string
from typing import (
TYPE_CHECKING,
Expand All @@ -36,6 +37,11 @@
StepConfigurationUpdate,
StepSpec,
)
from zenml.constants import (
ENV_ZENML_ACTIVE_STACK_ID,
ENV_ZENML_ACTIVE_WORKSPACE_ID,
ENV_ZENML_STORE_PREFIX,
)
from zenml.environment import get_run_environment_dict
from zenml.exceptions import StackValidationError
from zenml.models import PipelineDeploymentBase
Expand All @@ -50,6 +56,8 @@

from zenml.logger import get_logger

ENVIRONMENT_VARIABLE_PREFIX = "__ZENML__"

logger = get_logger(__file__)


Expand Down Expand Up @@ -104,13 +112,20 @@ def compile(
pipeline.configuration.substitutions,
)

pipeline_environment = finalize_environment_variables(
pipeline.configuration.environment
)
pipeline_settings = self._filter_and_validate_settings(
settings=pipeline.configuration.settings,
configuration_level=ConfigurationLevel.PIPELINE,
stack=stack,
)
with pipeline.__suppress_configure_warnings__():
pipeline.configure(settings=pipeline_settings, merge=False)
pipeline.configure(
environment=pipeline_environment,
settings=pipeline_settings,
merge=False,
)

settings_to_passdown = {
key: settings
Expand All @@ -121,6 +136,7 @@ def compile(
steps = {
invocation_id: self._compile_step_invocation(
invocation=invocation,
pipeline_environment=pipeline_environment,
pipeline_settings=settings_to_passdown,
pipeline_extra=pipeline.configuration.extra,
stack=stack,
Expand Down Expand Up @@ -210,6 +226,7 @@ def _apply_run_configuration(
enable_artifact_metadata=config.enable_artifact_metadata,
enable_artifact_visualization=config.enable_artifact_visualization,
enable_step_logs=config.enable_step_logs,
environment=config.environment,
settings=config.settings,
tags=config.tags,
extra=config.extra,
Expand Down Expand Up @@ -427,6 +444,7 @@ def _get_step_spec(
def _compile_step_invocation(
self,
invocation: "StepInvocation",
pipeline_environment: Optional[Dict[str, Any]],
pipeline_settings: Dict[str, "BaseSettings"],
pipeline_extra: Dict[str, Any],
stack: "Stack",
Expand All @@ -438,7 +456,9 @@ def _compile_step_invocation(
Args:
invocation: The step invocation to compile.
pipeline_settings: settings configured on the
pipeline_environment: Environment variables configured for the
pipeline.
pipeline_settings: Settings configured on the
pipeline of the step.
pipeline_extra: Extra values configured on the pipeline of the step.
stack: The stack on which the pipeline will be run.
Expand All @@ -463,6 +483,9 @@ def _compile_step_invocation(
step.configuration.settings, stack=stack
)
step_spec = self._get_step_spec(invocation=invocation)
step_environment = finalize_environment_variables(
step.configuration.environment
)
step_settings = self._filter_and_validate_settings(
settings=step.configuration.settings,
configuration_level=ConfigurationLevel.STEP,
Expand All @@ -473,13 +496,15 @@ def _compile_step_invocation(
step_on_success_hook_source = step.configuration.success_hook_source

step.configure(
environment=pipeline_environment,
settings=pipeline_settings,
extra=pipeline_extra,
on_failure=pipeline_failure_hook_source,
on_success=pipeline_success_hook_source,
merge=False,
)
step.configure(
environment=step_environment,
settings=step_settings,
extra=step_extra,
on_failure=step_on_failure_hook_source,
Expand Down Expand Up @@ -635,3 +660,50 @@ def convert_component_shortcut_settings_keys(
)

settings[key] = component_settings


def finalize_environment_variables(
environment: Dict[str, Any],
) -> Dict[str, str]:
"""Finalize the user environment variables.
This function adds all __ZENML__ prefixed environment variables from the
local client environment to the explicit user-defined variables.
Args:
environment: The explicit user-defined environment variables.
Returns:
The finalized user environment variables.
"""
environment = {key: str(value) for key, value in environment.items()}

for key, value in os.environ.items():
if key.startswith(ENVIRONMENT_VARIABLE_PREFIX):
key_without_prefix = key[len(ENVIRONMENT_VARIABLE_PREFIX) :]

if (
key_without_prefix in environment
and value != environment[key_without_prefix]
):
logger.warning(
"Got multiple values for environment variable `%s`.",
key_without_prefix,
)
else:
environment[key_without_prefix] = value

finalized_env = {}

for key, value in environment.items():
if key.upper().startswith(ENV_ZENML_STORE_PREFIX) or key.upper() in [
ENV_ZENML_ACTIVE_WORKSPACE_ID,
ENV_ZENML_ACTIVE_STACK_ID,
]:
logger.warning(
"Not allowed to set `%s` config environment variable.", key
)
continue
finalized_env[key] = str(value)

return finalized_env
1 change: 1 addition & 0 deletions src/zenml/config/pipeline_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
enable_artifact_metadata: Optional[bool] = None
enable_artifact_visualization: Optional[bool] = None
enable_step_logs: Optional[bool] = None
environment: Dict[str, Any] = {}
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
tags: Optional[List[str]] = None
extra: Dict[str, Any] = {}
Expand Down
1 change: 1 addition & 0 deletions src/zenml/config/pipeline_run_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class PipelineRunConfiguration(
default=None, union_mode="left_to_right"
)
steps: Dict[str, StepConfigurationUpdate] = {}
environment: Dict[str, Any] = {}
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
tags: Optional[List[str]] = None
extra: Dict[str, Any] = {}
Expand Down
1 change: 1 addition & 0 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class StepConfigurationUpdate(StrictBaseModel):
step_operator: Optional[str] = None
experiment_tracker: Optional[str] = None
parameters: Dict[str, Any] = {}
environment: Dict[str, Any] = {}
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
extra: Dict[str, Any] = {}
failure_hook_source: Optional[SourceWithValidator] = None
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/orchestrators/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def generate_cache_key(
hash_.update(key.encode())
hash_.update(str(value).encode())

# User-defined environment variables
for key, value in sorted(step.config.environment.items()):
hash_.update(key.encode())
hash_.update(str(value).encode())

return hash_.hexdigest()


Expand Down
151 changes: 80 additions & 71 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
parse_return_type_annotations,
resolve_type_annotation,
)
from zenml.utils import materializer_utils, source_utils, string_utils
from zenml.utils import (
env_utils,
materializer_utils,
source_utils,
string_utils,
)
from zenml.utils.typing_utils import get_origin, is_union

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,86 +188,90 @@ def run(
)

step_failed = False
try:
return_values = step_instance.call_entrypoint(
**function_params
)
except BaseException as step_exception: # noqa: E722
step_failed = True
if not handle_bool_env_var(
ENV_ZENML_IGNORE_FAILURE_HOOK, False
):
if (
failure_hook_source
:= self.configuration.failure_hook_source
):
logger.info("Detected failure hook. Running...")
self.load_and_run_hook(
failure_hook_source,
step_exception=step_exception,
)
raise
finally:
with env_utils.temporary_environment(step_run.config.environment):
try:
step_run_metadata = self._stack.get_step_run_metadata(
info=step_run_info,
)
publish_step_run_metadata(
step_run_id=step_run_info.step_run_id,
step_run_metadata=step_run_metadata,
)
self._stack.cleanup_step_run(
info=step_run_info, step_failed=step_failed
return_values = step_instance.call_entrypoint(
**function_params
)
if not step_failed:
except BaseException as step_exception: # noqa: E722
step_failed = True
if not handle_bool_env_var(
ENV_ZENML_IGNORE_FAILURE_HOOK, False
):
if (
success_hook_source
:= self.configuration.success_hook_source
failure_hook_source
:= self.configuration.failure_hook_source
):
logger.info("Detected success hook. Running...")
logger.info("Detected failure hook. Running...")
self.load_and_run_hook(
success_hook_source,
step_exception=None,
failure_hook_source,
step_exception=step_exception,
)

# Store and publish the output artifacts of the step function.
output_data = self._validate_outputs(
return_values, output_annotations
)
artifact_metadata_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
raise
finally:
try:
step_run_metadata = self._stack.get_step_run_metadata(
info=step_run_info,
)
artifact_visualization_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
publish_step_run_metadata(
step_run_id=step_run_info.step_run_id,
step_run_metadata=step_run_metadata,
)
output_artifacts = self._store_output_artifacts(
output_data=output_data,
output_artifact_uris=output_artifact_uris,
output_materializers=output_materializers,
output_annotations=output_annotations,
artifact_metadata_enabled=artifact_metadata_enabled,
artifact_visualization_enabled=artifact_visualization_enabled,
self._stack.cleanup_step_run(
info=step_run_info, step_failed=step_failed
)

if (
model_version := step_run.model_version
or pipeline_run.model_version
):
from zenml.orchestrators import step_run_utils

step_run_utils.link_output_artifacts_to_model_version(
artifacts={
k: [v] for k, v in output_artifacts.items()
},
model_version=model_version,
if not step_failed:
if (
success_hook_source
:= self.configuration.success_hook_source
):
logger.info(
"Detected success hook. Running..."
)
self.load_and_run_hook(
success_hook_source,
step_exception=None,
)

# Store and publish the output artifacts of the step function.
output_data = self._validate_outputs(
return_values, output_annotations
)
finally:
step_context._cleanup_registry.execute_callbacks(
raise_on_exception=False
)
StepContext._clear() # Remove the step context singleton
artifact_metadata_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
)
artifact_visualization_enabled = is_setting_enabled(
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
)
output_artifacts = self._store_output_artifacts(
output_data=output_data,
output_artifact_uris=output_artifact_uris,
output_materializers=output_materializers,
output_annotations=output_annotations,
artifact_metadata_enabled=artifact_metadata_enabled,
artifact_visualization_enabled=artifact_visualization_enabled,
)

if (
model_version := step_run.model_version
or pipeline_run.model_version
):
from zenml.orchestrators import step_run_utils

step_run_utils.link_output_artifacts_to_model_version(
artifacts={
k: [v]
for k, v in output_artifacts.items()
},
model_version=model_version,
)
finally:
step_context._cleanup_registry.execute_callbacks(
raise_on_exception=False
)
StepContext._clear() # Remove the step context singleton

# Update the status and output artifacts of the step run.
output_artifact_ids = {
Expand Down
Loading

0 comments on commit 9e77cbd

Please sign in to comment.