Ignore type of any tfx proto file #453

Merged
merged 1 commit on Mar 10, 2022
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -172,6 +172,11 @@ follow_imports = "skip"

# end of fix

# import all tfx.proto.* files as `Any`
[[tool.mypy.overrides]]
module = "tfx.proto.*"
follow_imports = "skip"

[[tool.mypy.overrides]]
module = [
"tensorflow.*",
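For context, with the new override above, mypy resolves anything imported from `tfx.proto.*` as `Any`, which is why the inline `# type: ignore[valid-type]` and `# type: ignore[attr-defined]` comments can be dropped throughout the files below. A minimal, hypothetical sketch of the effect (the function here is illustrative, not part of this PR):

```python
# Illustration only: with follow_imports = "skip" for tfx.proto.*, mypy treats
# pipeline_pb2 as Any, so neither the annotation nor the attribute access below
# needs an inline ignore comment.
from tfx.proto.orchestration import pipeline_pb2


def get_node_id(node: pipeline_pb2.PipelineNode) -> str:
    # node is effectively Any to mypy, so .node_info.id type-checks cleanly.
    return node.node_info.id
```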
40 changes: 17 additions & 23 deletions src/zenml/integrations/kubeflow/container_entrypoint.py
@@ -146,7 +146,7 @@ def _render_artifact_as_mdstr(single_artifact: artifact.Artifact) -> str:


def _dump_ui_metadata(
node: pipeline_pb2.PipelineNode, # type: ignore[valid-type]
node: pipeline_pb2.PipelineNode,
execution_info: data_types.ExecutionInfo,
ui_metadata_path: str = "/tmp/mlpipeline-ui-metadata.json",
) -> None:
@@ -172,7 +172,7 @@ def _dump_ui_metadata(
)

def _dump_input_populated_artifacts(
node_inputs: MutableMapping[str, pipeline_pb2.InputSpec], # type: ignore[valid-type] # noqa
node_inputs: MutableMapping[str, pipeline_pb2.InputSpec],
name_to_artifacts: Dict[str, List[artifact.Artifact]],
) -> List[str]:
"""Dump artifacts markdown string for inputs.
@@ -195,7 +195,7 @@ def _dump_input_populated_artifacts(
)
# There must be at least one channel in an input, and all channels in
# an input share the same artifact type.
artifact_type = spec.channels[0].artifact_query.type.name # type: ignore[attr-defined] # noqa
artifact_type = spec.channels[0].artifact_query.type.name
rendered_list.append(
"## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
name=_sanitize_underscore(name),
@@ -207,7 +207,7 @@ def _dump_input_populated_artifacts(
return rendered_list

def _dump_output_populated_artifacts(
node_outputs: MutableMapping[str, pipeline_pb2.OutputSpec], # type: ignore[valid-type] # noqa
node_outputs: MutableMapping[str, pipeline_pb2.OutputSpec],
name_to_artifacts: Dict[str, List[artifact.Artifact]],
) -> List[str]:
"""Dump artifacts markdown string for outputs.
@@ -230,7 +230,7 @@ def _dump_output_populated_artifacts(
)
# There must be at least one channel in an input, and all channels
# in an input share the same artifact type.
artifact_type = spec.artifact_spec.type.name # type: ignore[attr-defined] # noqa
artifact_type = spec.artifact_spec.type.name
rendered_list.append(
"## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
name=_sanitize_underscore(name),
@@ -244,7 +244,7 @@ def _dump_output_populated_artifacts(
src_str_inputs = "# Inputs:\n{}".format(
"".join(
_dump_input_populated_artifacts(
node_inputs=node.inputs.inputs, # type: ignore[attr-defined] # noqa
node_inputs=node.inputs.inputs,
name_to_artifacts=execution_info.input_dict or {},
)
)
@@ -254,7 +254,7 @@ def _dump_output_populated_artifacts(
src_str_outputs = "# Outputs:\n{}".format(
"".join(
_dump_output_populated_artifacts(
node_outputs=node.outputs.outputs, # type: ignore[attr-defined] # noqa
node_outputs=node.outputs.outputs,
name_to_artifacts=execution_info.output_dict or {},
)
)
@@ -273,7 +273,7 @@ def _dump_output_populated_artifacts(
}
]
# Add Tensorboard view for ModelRun outputs.
for name, spec in node.outputs.outputs.items(): # type: ignore[attr-defined] # noqa
for name, spec in node.outputs.outputs.items():
if (
spec.artifact_spec.type.name
== standard_artifacts.ModelRun.TYPE_NAME
@@ -294,11 +294,11 @@ def _dump_output_populated_artifacts(


def _get_pipeline_node(
pipeline: pipeline_pb2.Pipeline, node_id: str # type: ignore[valid-type] # noqa
) -> pipeline_pb2.PipelineNode: # type: ignore[valid-type]
pipeline: pipeline_pb2.Pipeline, node_id: str
) -> pipeline_pb2.PipelineNode:
"""Gets node of a certain node_id from a pipeline."""
result: Optional[pipeline_pb2.PipelineNode] = None # type: ignore[valid-type] # noqa
for node in pipeline.nodes: # type: ignore[attr-defined] # noqa
result: Optional[pipeline_pb2.PipelineNode] = None
for node in pipeline.nodes:
if (
node.WhichOneof("node") == "pipeline_node"
and node.pipeline_node.node_info.id == node_id
@@ -318,25 +318,19 @@ def _parse_runtime_parameter_str(param: str) -> Tuple[str, Property]:
# Runtime parameter format: "{name}=(INT|DOUBLE|STRING):{value}"
name, value_and_type = param.split("=", 1)
value_type, value = value_and_type.split(":", 1)
if (
value_type
== pipeline_pb2.RuntimeParameter.Type.Name( # type: ignore[attr-defined] # noqa
pipeline_pb2.RuntimeParameter.INT # type: ignore[attr-defined]
)
if value_type == pipeline_pb2.RuntimeParameter.Type.Name(
pipeline_pb2.RuntimeParameter.INT
):
return name, int(value)
elif (
value_type
== pipeline_pb2.RuntimeParameter.Type.Name( # type: ignore[attr-defined] # noqa
pipeline_pb2.RuntimeParameter.DOUBLE # type: ignore[attr-defined]
)
elif value_type == pipeline_pb2.RuntimeParameter.Type.Name(
pipeline_pb2.RuntimeParameter.DOUBLE
):
return name, float(value)
return name, value


def _resolve_runtime_parameters(
tfx_ir: pipeline_pb2.Pipeline, # type: ignore[valid-type] # noqa
tfx_ir: pipeline_pb2.Pipeline,
run_name: str,
parameters: Optional[List[str]],
) -> None:
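For reference, the runtime-parameter strings parsed above follow the `{name}=(INT|DOUBLE|STRING):{value}` format noted in `_parse_runtime_parameter_str`. A small standalone sketch of the split logic, with a made-up parameter name and value:

```python
# Hypothetical encoded parameter; mirrors the split logic in
# _parse_runtime_parameter_str.
encoded = "num_steps=INT:100"
name, value_and_type = encoded.split("=", 1)
value_type, value = value_and_type.split(":", 1)
assert (name, value_type, value) == ("num_steps", "INT", "100")
assert int(value) == 100  # INT-typed values are converted with int()
```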
@@ -52,12 +52,12 @@
def _encode_runtime_parameter(param: data_types.RuntimeParameter) -> str:
"""Encode a runtime parameter into a placeholder for value substitution."""
if param.ptype is int:
type_enum = pipeline_pb2.RuntimeParameter.INT # type: ignore[attr-defined] # noqa
type_enum = pipeline_pb2.RuntimeParameter.INT
elif param.ptype is float:
type_enum = pipeline_pb2.RuntimeParameter.DOUBLE # type: ignore[attr-defined] # noqa
type_enum = pipeline_pb2.RuntimeParameter.DOUBLE
else:
type_enum = pipeline_pb2.RuntimeParameter.STRING # type: ignore[attr-defined] # noqa
type_str = pipeline_pb2.RuntimeParameter.Type.Name(type_enum) # type: ignore[attr-defined] # noqa
type_enum = pipeline_pb2.RuntimeParameter.STRING
type_str = pipeline_pb2.RuntimeParameter.Type.Name(type_enum)
return f"{param.name}={type_str}:{str(dsl.PipelineParam(name=param.name))}"


@@ -90,7 +90,7 @@ def __init__(
component: tfx_base_component.BaseComponent,
depends_on: Set[dsl.ContainerOp],
image: str,
tfx_ir: pipeline_pb2.Pipeline, # type: ignore[valid-type]
tfx_ir: pipeline_pb2.Pipeline,
pod_labels_to_attach: Dict[str, str],
main_module: str,
step_module: str,
@@ -306,14 +306,10 @@ def _construct_pipeline_graph(
runtime_configuration: The runtime configuration
"""
component_to_kfp_op: Dict[base_node.BaseNode, dsl.ContainerOp] = {}
tfx_ir: Pb2Pipeline = self._generate_tfx_ir( # type:ignore[valid-type]
pipeline
)
tfx_ir: Pb2Pipeline = self._generate_tfx_ir(pipeline)

for node in tfx_ir.nodes: # type:ignore[attr-defined]
pipeline_node: PipelineNode = ( # type:ignore[valid-type]
node.pipeline_node
)
for node in tfx_ir.nodes:
pipeline_node: PipelineNode = node.pipeline_node

# Add the stack as context to each pipeline node:
context_utils.add_context_to_node(
@@ -382,32 +378,30 @@ def _del_unused_field(
del message_dict[item]

def _dehydrate_tfx_ir(
self, original_pipeline: Pb2Pipeline, node_id: str # type: ignore[valid-type] # noqa
) -> Pb2Pipeline: # type: ignore[valid-type]
self, original_pipeline: Pb2Pipeline, node_id: str
) -> Pb2Pipeline:
"""Dehydrate the TFX IR to remove unused fields."""
pipeline = copy.deepcopy(original_pipeline)
for node in pipeline.nodes: # type: ignore[attr-defined]
for node in pipeline.nodes:
if (
node.WhichOneof("node") == "pipeline_node"
and node.pipeline_node.node_info.id == node_id
):
del pipeline.nodes[:] # type: ignore[attr-defined]
pipeline.nodes.extend([node]) # type: ignore[attr-defined]
del pipeline.nodes[:]
pipeline.nodes.extend([node])
break

deployment_config = IntermediateDeploymentConfig()
pipeline.deployment_config.Unpack(deployment_config) # type: ignore[attr-defined] # noqa
pipeline.deployment_config.Unpack(deployment_config)
self._del_unused_field(node_id, deployment_config.executor_specs)
self._del_unused_field(node_id, deployment_config.custom_driver_specs)
self._del_unused_field(
node_id, deployment_config.node_level_platform_configs
)
pipeline.deployment_config.Pack(deployment_config) # type: ignore[attr-defined] # noqa
pipeline.deployment_config.Pack(deployment_config)
return pipeline

def _generate_tfx_ir(
self, pipeline: TfxPipeline
) -> Pb2Pipeline: # type: ignore[valid-type]
def _generate_tfx_ir(self, pipeline: TfxPipeline) -> Pb2Pipeline:
"""Generate the TFX IR from the logical TFX pipeline."""
result = self._tfx_compiler.compile(pipeline)
return result
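As background for the `Unpack`/`Pack` calls in `_dehydrate_tfx_ir` above: `deployment_config` is a `google.protobuf.Any` field, and the well-known-type helpers copy a concrete message into it and back out. A minimal sketch of that round trip, using an arbitrary wrapper message rather than the real `IntermediateDeploymentConfig`:

```python
# Minimal google.protobuf.Any round trip; the payload type here is arbitrary.
from google.protobuf import any_pb2, wrappers_pb2

payload = wrappers_pb2.StringValue(value="example")
packed = any_pb2.Any()
packed.Pack(payload)  # stores the serialized payload plus its type URL

restored = wrappers_pb2.StringValue()
assert packed.Unpack(restored)  # True when the stored type URL matches
assert restored.value == "example"
```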
14 changes: 6 additions & 8 deletions src/zenml/orchestrators/context_utils.py
@@ -29,7 +29,7 @@


def add_context_to_node(
pipeline_node: "pipeline_pb2.PipelineNode", # type: ignore[valid-type]
pipeline_node: "pipeline_pb2.PipelineNode",
type_: str,
name: str,
properties: Dict[str, str],
@@ -44,16 +44,14 @@ def add_context_to_node(
properties: dictionary of strings as properties of the context
"""
# Add a new context to the pipeline
context: "pipeline_pb2.ContextSpec" = ( # type: ignore[valid-type]
pipeline_node.contexts.contexts.add() # type: ignore[attr-defined]
)
context: "pipeline_pb2.ContextSpec" = pipeline_node.contexts.contexts.add()
# Adding the type of context
context.type.name = type_ # type: ignore[attr-defined]
context.type.name = type_
# Setting the name of the context
context.name.field_value.string_value = name # type: ignore[attr-defined]
context.name.field_value.string_value = name
# Setting the properties of the context depending on attribute type
for key, value in properties.items():
c_property = context.properties[key] # type:ignore[attr-defined]
c_property = context.properties[key]
c_property.field_value.string_value = value


@@ -94,7 +92,7 @@ def _inner_generator(


def add_runtime_configuration_to_node(
pipeline_node: "pipeline_pb2.PipelineNode", # type: ignore[valid-type]
pipeline_node: "pipeline_pb2.PipelineNode",
runtime_config: RuntimeConfiguration,
) -> None:
"""
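For context, a hedged usage sketch of `add_context_to_node` as changed above; the context type, stack name, and properties are made up for illustration and are not taken from this PR:

```python
from tfx.proto.orchestration import pipeline_pb2

from zenml.orchestrators import context_utils

node = pipeline_pb2.PipelineNode()
context_utils.add_context_to_node(
    node,
    type_="stack",  # assumed context type name
    name="local_stack",
    properties={"orchestrator": "local", "artifact_store": "local"},
)
# The helper adds one ContextSpec with the given name and string properties.
assert node.contexts.contexts[0].name.field_value.string_value == "local_stack"
```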
12 changes: 6 additions & 6 deletions src/zenml/orchestrators/local/local_orchestrator.py
@@ -93,7 +93,7 @@ def run_pipeline(
tfx_pipeline.pipeline_info.pipeline_root
)

pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline) # type: ignore[valid-type]
pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

# Substitute the runtime parameter to be a concrete run_id
runtime_parameter_utils.substitute_runtime_parameter(
@@ -115,8 +115,8 @@

# Run each component. Note that the pipeline.components list is in
# topological order.
for node in pb2_pipeline.nodes: # type: ignore[attr-defined]
pipeline_node: PipelineNode = node.pipeline_node # type: ignore[valid-type]
for node in pb2_pipeline.nodes:
pipeline_node: PipelineNode = node.pipeline_node

# fill out that context
context_utils.add_context_to_node(
@@ -131,16 +131,16 @@
pipeline_node, runtime_configuration
)

node_id = pipeline_node.node_info.id # type:ignore[attr-defined]
node_id = pipeline_node.node_info.id
executor_spec = runner_utils.extract_executor_spec(
deployment_config, node_id
)
custom_driver_spec = runner_utils.extract_custom_driver_spec(
deployment_config, node_id
)

p_info = pb2_pipeline.pipeline_info # type:ignore[attr-defined]
r_spec = pb2_pipeline.runtime_spec # type:ignore[attr-defined]
p_info = pb2_pipeline.pipeline_info
r_spec = pb2_pipeline.runtime_spec

component_launcher = launcher.Launcher(
pipeline_node=pipeline_node,
10 changes: 5 additions & 5 deletions src/zenml/orchestrators/utils.py
@@ -14,7 +14,7 @@

import json
import time
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, cast

import tfx.orchestration.pipeline as tfx_pipeline
from tfx.orchestration.portable import data_types, launcher
@@ -72,8 +72,8 @@ def get_cache_status(
The caching status of a `tfx` step as a boolean value.
"""
if execution_info is None:
logger.warn("No execution info found when checking for cache status.")
return
logger.warning("No execution info found when checking cache status.")
return False

status = False
repository = Repository()
@@ -91,7 +91,7 @@
pipeline_name = execution_info.pipeline_info.id
else:
raise KeyError(f"No pipeline info found for step `{step_name}`.")
pipeline_run_name = execution_info.pipeline_run_id
pipeline_run_name = cast(str, execution_info.pipeline_run_id)
pipeline = metadata_store.get_pipeline(pipeline_name)
if pipeline is None:
logger.error(f"Pipeline {pipeline_name} not found in Metadata Store.")
@@ -116,7 +116,7 @@ def execute_step(
step_name_param = (
INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
)
pipeline_step_name = tfx_launcher._pipeline_node.node_info.id # type: ignore[attr-defined]
pipeline_step_name = tfx_launcher._pipeline_node.node_info.id
start_time = time.time()
logger.info(f"Step `{pipeline_step_name}` has started.")
try:
2 changes: 1 addition & 1 deletion src/zenml/steps/utils.py
@@ -436,7 +436,7 @@ def Do(
# the pipeline runtime, such as the current step name and the current
# pipeline run ID
with StepEnvironment(
pipeline_name=self._context.pipeline_info.id, # type: ignore[attr-defined]
pipeline_name=self._context.pipeline_info.id,
pipeline_run_id=self._context.pipeline_run_id,
step_name=getattr(self, PARAM_STEP_NAME),
):