Remove unused tfx kubeflow code #239

Merged: 1 commit, Dec 15, 2021
src/zenml/integrations/kubeflow/container_entrypoint.py (2 changes: 1 addition & 1 deletion)
@@ -466,7 +466,7 @@ def main() -> None:

     metadata_store = Repository().get_active_stack().metadata_store
     if isinstance(metadata_store, KubeflowMetadataStore):
-        # setup the metadata connection so it connects to the internal kubeflow
+        # set up the metadata connection so it connects to the internal kubeflow
         # mysql database
         connection_config = _get_grpc_metadata_connection_config()
     else:
@@ -22,13 +22,11 @@
 """
 from typing import Dict, List, Set

-from absl import logging
 from google.protobuf import json_format
 from kfp import dsl
 from kubernetes import client as k8s_client
 from tfx.dsl.components.base import base_node as tfx_base_node
 from tfx.orchestration import data_types
-from tfx.orchestration import pipeline as tfx_pipeline
 from tfx.proto.orchestration import pipeline_pb2

 from zenml.artifact_stores.local_artifact_store import LocalArtifactStore
@@ -37,9 +35,11 @@
 from zenml.integrations.kubeflow.orchestrators import kubeflow_utils as utils
 from zenml.metadata.sqlite_metadata_wrapper import SQLiteMetadataStore

-_COMMAND = ["python", "-m", "zenml.integrations.kubeflow.container_entrypoint"]
-
-_WORKFLOW_ID_KEY = "WORKFLOW_ID"
+CONTAINER_ENTRYPOINT_COMMAND = [
+    "python",
+    "-m",
+    "zenml.integrations.kubeflow.container_entrypoint",
+]


 def _encode_runtime_parameter(param: data_types.RuntimeParameter) -> str:
@@ -65,7 +65,6 @@ def __init__(
         self,
         component: tfx_base_node.BaseNode,
         depends_on: Set[dsl.ContainerOp],
-        pipeline: tfx_pipeline.Pipeline,
         image: str,
         tfx_ir: pipeline_pb2.Pipeline,  # type: ignore[valid-type]
         pod_labels_to_attach: Dict[str, str],
@@ -81,7 +80,6 @@ def __init__(
             component: The logical TFX component to wrap.
             depends_on: The set of upstream KFP ContainerOp components that this
                 component will depend on.
-            pipeline: The logical TFX pipeline to which this component belongs.
             image: The container image to use for this component.
             tfx_ir: The TFX intermedia representation of the pipeline.
             pod_labels_to_attach: Dict of pod labels to attach to the GKE pod.
@@ -133,7 +131,7 @@ def __init__(

         self.container_op = dsl.ContainerOp(
             name=component.id,
-            command=_COMMAND,
+            command=CONTAINER_ENTRYPOINT_COMMAND,
             image=image,
             arguments=arguments,
             output_artifact_paths={
@@ -142,44 +140,14 @@ def __init__(
             pvolumes=volumes,
         )

-        logging.info(
-            "Adding upstream dependencies for component %s",
-            self.container_op.name,
-        )
         for op in depends_on:
-            logging.info(" -> Component: %s", op.name)
             self.container_op.after(op)

-        # TODO(b/140172100): Document the use of additional_pipeline_args.
-        if _WORKFLOW_ID_KEY in pipeline.additional_pipeline_args:
-            # Allow overriding pipeline's run_id externally, primarily for testing.
-            self.container_op.container.add_env_variable(
-                k8s_client.V1EnvVar(
-                    name=_WORKFLOW_ID_KEY,
-                    value=pipeline.additional_pipeline_args[_WORKFLOW_ID_KEY],
-                )
-            )
-        else:
-            # Add the Argo workflow ID to the container's environment variable so it
-            # can be used to uniquely place pipeline outputs under the pipeline_root.
-            field_path = "metadata.labels['workflows.argoproj.io/workflow']"
-            self.container_op.container.add_env_variable(
-                k8s_client.V1EnvVar(
-                    name=_WORKFLOW_ID_KEY,
-                    value_from=k8s_client.V1EnvVarSource(
-                        field_ref=k8s_client.V1ObjectFieldSelector(
-                            field_path=field_path
-                        )
-                    ),
-                )
-            )
-
         self.container_op.container.add_env_variable(
             k8s_client.V1EnvVar(
                 name=ENV_ZENML_PREVENT_PIPELINE_EXECUTION, value="True"
            )
         )

-        if pod_labels_to_attach:
-            for k, v in pod_labels_to_attach.items():
-                self.container_op.add_pod_label(k, v)
+        for k, v in pod_labels_to_attach.items():
+            self.container_op.add_pod_label(k, v)
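Note on the last chunk: the `if pod_labels_to_attach:` guard around the label loop is dropped because iterating an empty dict is a no-op, so the guard never changed behavior. A minimal sketch of that equivalence (the `attach_labels` / `add_label` names here are made up for illustration, not part of the ZenML API):

    def attach_labels(add_label, labels):
        # Looping over an empty dict performs zero iterations, so no guard is needed.
        for key, value in labels.items():
            add_label(key, value)

    attach_labels(print, {})                # prints nothing
    attach_labels(print, {"app": "zenml"})  # prints: app zenml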
@@ -91,7 +91,7 @@ def _mount_config_map_op(config_map_name: str) -> OpFunc:
     """

     def mount_config_map(container_op: dsl.ContainerOp) -> None:
-        """Mounts all key-value pairs found in the named Kubernetes ConfigMap."""
+        """Mounts all key-value pairs found in the Kubernetes ConfigMap."""
         config_map_ref = k8s_client.V1ConfigMapEnvSource(
             name=config_map_name, optional=True
         )
@@ -237,10 +237,9 @@ def __init__(
             str, List[data_types.RuntimeParameter]
         ] = collections.defaultdict(list)
         self._deduped_parameter_names: Set[str] = set()
-        if pod_labels_to_attach is None:
-            self._pod_labels_to_attach = get_default_pod_labels()
-        else:
-            self._pod_labels_to_attach = pod_labels_to_attach
+        self._pod_labels_to_attach = (
+            pod_labels_to_attach or get_default_pod_labels()
+        )

     def _parse_parameter_from_component(
         self, component: tfx_base_component.BaseComponent
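A quick note on this chunk: replacing the explicit `is None` check with `pod_labels_to_attach or get_default_pod_labels()` is equivalent for `None`, and additionally falls back to the defaults when an empty dict is passed (an empty label dict attaches nothing anyway). A small sketch of the difference, using made-up placeholder values rather than the real default labels:

    defaults = {"example-label": "default"}  # stand-in for get_default_pod_labels()

    def via_or(labels):
        return labels or defaults

    def via_is_none(labels):
        return defaults if labels is None else labels

    print(via_or(None) == via_is_none(None))  # True: both fall back to defaults
    print(via_or({}), via_is_none({}))        # defaults vs {}: only the `or` form falls back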
@@ -266,11 +265,6 @@ def _parse_parameter_from_component(
                 self._params_by_component_id[component.id].append(parameter)
                 if parameter.name not in self._deduped_parameter_names:
                     self._deduped_parameter_names.add(parameter.name)
-                    # TODO(b/178436919): Create a test to cover default value rendering
-                    # and move the external code reference over there.
-                    # The default needs to be serialized then passed to dsl.PipelineParam.
-                    # See
-                    # https://github.com/kubeflow/pipelines/blob/f65391309650fdc967586529e79af178241b4c2c/sdk/python/kfp/dsl/_pipeline_param.py#L154
                     dsl_parameter = dsl.PipelineParam(
                         name=parameter.name, value=str(parameter.default)
                     )
@@ -295,11 +289,12 @@ def _construct_pipeline_graph(
         component_to_kfp_op: Dict[base_node.BaseNode, dsl.ContainerOp] = {}
         tfx_ir = self._generate_tfx_ir(pipeline)

-        # Assumption: There is a partial ordering of components in the list, i.e.,
-        # if component A depends on component B and C, then A appears after B and C
-        # in the list.
+        # Assumption: There is a partial ordering of components in the list,
+        # i.e. if component A depends on component B and C, then A appears
+        # after B and C in the list.
         for component in pipeline.components:
-            # Keep track of the set of upstream dsl.ContainerOps for this component.
+            # Keep track of the set of upstream dsl.ContainerOps for this
+            # component.
             depends_on = set()

             for upstream_component in component.upstream_nodes:
@@ -324,7 +319,6 @@
                 step_function_name=component.id,
                 component=component,
                 depends_on=depends_on,
-                pipeline=pipeline,
                 image=self._kubeflow_config.image,
                 pod_labels_to_attach=self._pod_labels_to_attach,
                 tfx_ir=tfx_node_ir,
@@ -392,14 +386,12 @@ def run(self, pipeline: tfx_pipeline.Pipeline) -> None:
         )

         def _construct_pipeline() -> None:
-            """Constructs a Kubeflow pipeline.
-            Creates Kubeflow ContainerOps for each TFX component encountered in the
-            logical pipeline definition.
-            """
+            """Creates Kubeflow ContainerOps for each TFX component
+            encountered in the pipeline definition."""
             self._construct_pipeline_graph(pipeline)

-            # Need to run this first to get self._params populated. Then KFP compiler
-            # can correctly match default value with PipelineParam.
+            # Need to run this first to get self._params populated. Then KFP
+            # compiler can correctly match default value with PipelineParam.
             self._parse_parameter_from_pipeline(pipeline)
             # Create workflow spec and write out to package.
             self._compiler._create_and_write_workflow(
src/zenml/integrations/kubeflow/orchestrators/kubeflow_utils.py (25 changes: 0 additions & 25 deletions)
@@ -13,9 +13,6 @@
 # limitations under the License.
 """Common utility for Kubeflow-based orchestrator."""

-
-# utils.py should not be used in container_entrypoint.py because of its
-# dependency on KFP.
 from kfp import dsl
 from tfx.dsl.components.base import base_node
 from tfx.orchestration import data_types
@@ -31,25 +28,3 @@ def replace_placeholder(component: base_node.BaseNode) -> None:
             component.exec_properties[key] = str(
                 dsl.PipelineParam(name=exec_property.name)
             )
-
-
-def fix_brackets(placeholder: str) -> str:
-    """Fix the imbalanced brackets in placeholder.
-    When ptype is not null, regex matching might grab a placeholder with }
-    missing. This function fix the missing bracket.
-    Args:
-        placeholder: string placeholder of RuntimeParameter
-    Returns:
-        Placeholder with re-balanced brackets.
-    Raises:
-        RuntimeError: if left brackets are less than right brackets.
-    """
-    lcount = placeholder.count("{")
-    rcount = placeholder.count("}")
-    if lcount < rcount:
-        raise RuntimeError(
-            "Unexpected redundant left brackets found in {}".format(placeholder)
-        )
-    else:
-        patch = "".join(["}"] * (lcount - rcount))
-        return placeholder + patch
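For reviewers checking what this deletion drops: per its docstring, the removed `fix_brackets` helper only padded unmatched left braces. A small self-contained restatement of that behavior, kept here purely as a reference for the deleted code, not as new functionality:

    def fix_brackets(placeholder: str) -> str:
        # Append one "}" per unmatched "{"; more "}" than "{" is an error.
        lcount = placeholder.count("{")
        rcount = placeholder.count("}")
        if lcount < rcount:
            raise RuntimeError(
                "Unexpected redundant left brackets found in {}".format(placeholder)
            )
        return placeholder + "}" * (lcount - rcount)

    print(fix_brackets("{{'name': 'param'"))  # {{'name': 'param'}}
    print(fix_brackets("{'name': 'x'}"))      # {'name': 'x'}  (already balanced)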