From 458f1d755c7bedc47db7f2901b65819495cfd543 Mon Sep 17 00:00:00 2001
From: Michael Schuster
Date: Tue, 8 Mar 2022 15:29:33 +0100
Subject: [PATCH] Ignore types of all tfx proto files

---
 pyproject.toml                                     |  5 +++
 .../kubeflow/container_entrypoint.py               | 40 ++++++++-----------
 .../orchestrators/kubeflow_component.py            | 10 ++---
 .../orchestrators/kubeflow_dag_runner.py           | 28 +++++--------
 src/zenml/orchestrators/context_utils.py           | 14 +++----
 .../orchestrators/local/local_orchestrator.py      | 12 +++---
 src/zenml/orchestrators/utils.py                   | 10 ++---
 src/zenml/steps/utils.py                           |  2 +-
 8 files changed, 56 insertions(+), 65 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8d0247c3fb5..7c3568991a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -172,6 +172,11 @@ follow_imports = "skip"

 # end of fix

+# treat all tfx.proto.* modules as `Any`
+[[tool.mypy.overrides]]
+module = "tfx.proto.*"
+follow_imports = "skip"
+
 [[tool.mypy.overrides]]
 module = [
     "tensorflow.*",
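NOTE: the pyproject.toml override above is what enables the rest of this
patch. With `follow_imports = "skip"`, mypy stops following imports from
any `tfx.proto.*` module and types everything coming out of them as `Any`,
so the per-line `# type: ignore[valid-type]` and `# type: ignore[attr-defined]`
comments removed below are no longer needed. A minimal sketch of the effect
(illustrative only; the function is made up and not part of the patch):

    from tfx.proto.orchestration import pipeline_pb2

    def count_nodes(pipeline: pipeline_pb2.Pipeline) -> int:
        # With the override, pipeline_pb2 is `Any` to mypy, so both the
        # annotation above and the attribute access below type-check
        # without per-line ignores.
        return len(pipeline.nodes)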
diff --git a/src/zenml/integrations/kubeflow/container_entrypoint.py b/src/zenml/integrations/kubeflow/container_entrypoint.py
index df97a76f16e..bee5e4c9189 100644
--- a/src/zenml/integrations/kubeflow/container_entrypoint.py
+++ b/src/zenml/integrations/kubeflow/container_entrypoint.py
@@ -146,7 +146,7 @@ def _render_artifact_as_mdstr(single_artifact: artifact.Artifact) -> str:


 def _dump_ui_metadata(
-    node: pipeline_pb2.PipelineNode,  # type: ignore[valid-type]
+    node: pipeline_pb2.PipelineNode,
     execution_info: data_types.ExecutionInfo,
     ui_metadata_path: str = "/tmp/mlpipeline-ui-metadata.json",
 ) -> None:
@@ -172,7 +172,7 @@ def _dump_ui_metadata(
     )

     def _dump_input_populated_artifacts(
-        node_inputs: MutableMapping[str, pipeline_pb2.InputSpec],  # type: ignore[valid-type] # noqa
+        node_inputs: MutableMapping[str, pipeline_pb2.InputSpec],
         name_to_artifacts: Dict[str, List[artifact.Artifact]],
     ) -> List[str]:
         """Dump artifacts markdown string for inputs.
@@ -195,7 +195,7 @@ def _dump_input_populated_artifacts(
             )
             # There must be at least a channel in a input, and all channels in
             # a input share the same artifact type.
-            artifact_type = spec.channels[0].artifact_query.type.name  # type: ignore[attr-defined] # noqa
+            artifact_type = spec.channels[0].artifact_query.type.name
             rendered_list.append(
                 "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                     name=_sanitize_underscore(name),
@@ -207,7 +207,7 @@ def _dump_input_populated_artifacts(
         return rendered_list

     def _dump_output_populated_artifacts(
-        node_outputs: MutableMapping[str, pipeline_pb2.OutputSpec],  # type: ignore[valid-type] # noqa
+        node_outputs: MutableMapping[str, pipeline_pb2.OutputSpec],
         name_to_artifacts: Dict[str, List[artifact.Artifact]],
     ) -> List[str]:
         """Dump artifacts markdown string for outputs.
@@ -230,7 +230,7 @@ def _dump_output_populated_artifacts(
             )
             # There must be at least a channel in a input, and all channels
             # in a input share the same artifact type.
-            artifact_type = spec.artifact_spec.type.name  # type: ignore[attr-defined] # noqa
+            artifact_type = spec.artifact_spec.type.name
             rendered_list.append(
                 "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                     name=_sanitize_underscore(name),
@@ -244,7 +244,7 @@ def _dump_output_populated_artifacts(
     src_str_inputs = "# Inputs:\n{}".format(
         "".join(
             _dump_input_populated_artifacts(
-                node_inputs=node.inputs.inputs,  # type: ignore[attr-defined] # noqa
+                node_inputs=node.inputs.inputs,
                 name_to_artifacts=execution_info.input_dict or {},
             )
         )
@@ -254,7 +254,7 @@ def _dump_output_populated_artifacts(
     src_str_outputs = "# Outputs:\n{}".format(
         "".join(
             _dump_output_populated_artifacts(
-                node_outputs=node.outputs.outputs,  # type: ignore[attr-defined] # noqa
+                node_outputs=node.outputs.outputs,
                 name_to_artifacts=execution_info.output_dict or {},
             )
         )
@@ -273,7 +273,7 @@ def _dump_output_populated_artifacts(
         }
     ]
     # Add Tensorboard view for ModelRun outputs.
-    for name, spec in node.outputs.outputs.items():  # type: ignore[attr-defined] # noqa
+    for name, spec in node.outputs.outputs.items():
         if (
             spec.artifact_spec.type.name
             == standard_artifacts.ModelRun.TYPE_NAME
@@ -294,11 +294,11 @@ def _dump_output_populated_artifacts(


 def _get_pipeline_node(
-    pipeline: pipeline_pb2.Pipeline, node_id: str  # type: ignore[valid-type] # noqa
-) -> pipeline_pb2.PipelineNode:  # type: ignore[valid-type]
+    pipeline: pipeline_pb2.Pipeline, node_id: str
+) -> pipeline_pb2.PipelineNode:
     """Gets node of a certain node_id from a pipeline."""
-    result: Optional[pipeline_pb2.PipelineNode] = None  # type: ignore[valid-type] # noqa
-    for node in pipeline.nodes:  # type: ignore[attr-defined] # noqa
+    result: Optional[pipeline_pb2.PipelineNode] = None
+    for node in pipeline.nodes:
         if (
             node.WhichOneof("node") == "pipeline_node"
             and node.pipeline_node.node_info.id == node_id
@@ -318,25 +318,19 @@ def _parse_runtime_parameter_str(param: str) -> Tuple[str, Property]:
     # Runtime parameter format: "{name}=(INT|DOUBLE|STRING):{value}"
     name, value_and_type = param.split("=", 1)
     value_type, value = value_and_type.split(":", 1)
-    if (
-        value_type
-        == pipeline_pb2.RuntimeParameter.Type.Name(  # type: ignore[attr-defined] # noqa
-            pipeline_pb2.RuntimeParameter.INT  # type: ignore[attr-defined]
-        )
+    if value_type == pipeline_pb2.RuntimeParameter.Type.Name(
+        pipeline_pb2.RuntimeParameter.INT
     ):
         return name, int(value)
-    elif (
-        value_type
-        == pipeline_pb2.RuntimeParameter.Type.Name(  # type: ignore[attr-defined] # noqa
-            pipeline_pb2.RuntimeParameter.DOUBLE  # type: ignore[attr-defined]
-        )
+    elif value_type == pipeline_pb2.RuntimeParameter.Type.Name(
+        pipeline_pb2.RuntimeParameter.DOUBLE
     ):
         return name, float(value)
     return name, value


 def _resolve_runtime_parameters(
-    tfx_ir: pipeline_pb2.Pipeline,  # type: ignore[valid-type] # noqa
+    tfx_ir: pipeline_pb2.Pipeline,
     run_name: str,
     parameters: Optional[List[str]],
 ) -> None:
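NOTE: the `_parse_runtime_parameter_str` cleanup above is the decoding half
of the wire format produced by `_encode_runtime_parameter` in the next file:
"{name}=(INT|DOUBLE|STRING):{value}". A hypothetical round trip (parameter
name and value invented for illustration):

    encoded = "epochs=INT:10"
    name, value_and_type = encoded.split("=", 1)
    value_type, value = value_and_type.split(":", 1)
    assert (name, value_type, int(value)) == ("epochs", "INT", 10)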
diff --git a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_component.py b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
index 11837187e94..a0ba2b06851 100644
--- a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
+++ b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_component.py
@@ -52,12 +52,12 @@ def _encode_runtime_parameter(param: data_types.RuntimeParameter) -> str:
     """Encode a runtime parameter into a placeholder for value substitution."""
     if param.ptype is int:
-        type_enum = pipeline_pb2.RuntimeParameter.INT  # type: ignore[attr-defined] # noqa
+        type_enum = pipeline_pb2.RuntimeParameter.INT
     elif param.ptype is float:
-        type_enum = pipeline_pb2.RuntimeParameter.DOUBLE  # type: ignore[attr-defined] # noqa
+        type_enum = pipeline_pb2.RuntimeParameter.DOUBLE
     else:
-        type_enum = pipeline_pb2.RuntimeParameter.STRING  # type: ignore[attr-defined] # noqa
-    type_str = pipeline_pb2.RuntimeParameter.Type.Name(type_enum)  # type: ignore[attr-defined] # noqa
+        type_enum = pipeline_pb2.RuntimeParameter.STRING
+    type_str = pipeline_pb2.RuntimeParameter.Type.Name(type_enum)
     return f"{param.name}={type_str}:{str(dsl.PipelineParam(name=param.name))}"


@@ -90,7 +90,7 @@ def __init__(
         component: tfx_base_component.BaseComponent,
         depends_on: Set[dsl.ContainerOp],
         image: str,
-        tfx_ir: pipeline_pb2.Pipeline,  # type: ignore[valid-type]
+        tfx_ir: pipeline_pb2.Pipeline,
         pod_labels_to_attach: Dict[str, str],
         main_module: str,
         step_module: str,
diff --git a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
index 1acc3e41804..287fe012f54 100644
--- a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
+++ b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_dag_runner.py
@@ -306,14 +306,10 @@ def _construct_pipeline_graph(
             runtime_configuration: The runtime configuration
         """
         component_to_kfp_op: Dict[base_node.BaseNode, dsl.ContainerOp] = {}
-        tfx_ir: Pb2Pipeline = self._generate_tfx_ir(  # type:ignore[valid-type]
-            pipeline
-        )
+        tfx_ir: Pb2Pipeline = self._generate_tfx_ir(pipeline)

-        for node in tfx_ir.nodes:  # type:ignore[attr-defined]
-            pipeline_node: PipelineNode = (  # type:ignore[valid-type]
-                node.pipeline_node
-            )
+        for node in tfx_ir.nodes:
+            pipeline_node: PipelineNode = node.pipeline_node

             # Add the stack as context to each pipeline node:
             context_utils.add_context_to_node(
@@ -382,32 +378,30 @@ def _del_unused_field(
             del message_dict[item]

     def _dehydrate_tfx_ir(
-        self, original_pipeline: Pb2Pipeline, node_id: str  # type: ignore[valid-type] # noqa
-    ) -> Pb2Pipeline:  # type: ignore[valid-type]
+        self, original_pipeline: Pb2Pipeline, node_id: str
+    ) -> Pb2Pipeline:
         """Dehydrate the TFX IR to remove unused fields."""
         pipeline = copy.deepcopy(original_pipeline)
-        for node in pipeline.nodes:  # type: ignore[attr-defined]
+        for node in pipeline.nodes:
             if (
                 node.WhichOneof("node") == "pipeline_node"
                 and node.pipeline_node.node_info.id == node_id
             ):
-                del pipeline.nodes[:]  # type: ignore[attr-defined]
-                pipeline.nodes.extend([node])  # type: ignore[attr-defined]
+                del pipeline.nodes[:]
+                pipeline.nodes.extend([node])
                 break

         deployment_config = IntermediateDeploymentConfig()
-        pipeline.deployment_config.Unpack(deployment_config)  # type: ignore[attr-defined] # noqa
+        pipeline.deployment_config.Unpack(deployment_config)
         self._del_unused_field(node_id, deployment_config.executor_specs)
         self._del_unused_field(node_id, deployment_config.custom_driver_specs)
         self._del_unused_field(
             node_id, deployment_config.node_level_platform_configs
         )
-        pipeline.deployment_config.Pack(deployment_config)  # type: ignore[attr-defined] # noqa
+        pipeline.deployment_config.Pack(deployment_config)
         return pipeline

-    def _generate_tfx_ir(
-        self, pipeline: TfxPipeline
-    ) -> Pb2Pipeline:  # type: ignore[valid-type]
+    def _generate_tfx_ir(self, pipeline: TfxPipeline) -> Pb2Pipeline:
         """Generate the TFX IR from the logical TFX pipeline."""
         result = self._tfx_compiler.compile(pipeline)
         return result
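NOTE: `_dehydrate_tfx_ir` above leans on the protobuf `Any` API:
`pipeline.deployment_config` is a `google.protobuf.Any` field, so it must be
unpacked into a concrete `IntermediateDeploymentConfig` before the per-node
fields can be pruned, then re-packed. A generic sketch of that round trip
using a well-known type (assumes only the protobuf runtime; not ZenML or
TFX code):

    from google.protobuf import any_pb2
    from google.protobuf.duration_pb2 import Duration

    wrapper = any_pb2.Any()
    wrapper.Pack(Duration(seconds=42))  # serializes and records a type URL
    unpacked = Duration()
    assert wrapper.Unpack(unpacked)     # True when the type URL matches
    assert unpacked.seconds == 42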
diff --git a/src/zenml/orchestrators/context_utils.py b/src/zenml/orchestrators/context_utils.py
index aaefeac1fe4..ecb6654e5ae 100644
--- a/src/zenml/orchestrators/context_utils.py
+++ b/src/zenml/orchestrators/context_utils.py
@@ -29,7 +29,7 @@


 def add_context_to_node(
-    pipeline_node: "pipeline_pb2.PipelineNode",  # type: ignore[valid-type]
+    pipeline_node: "pipeline_pb2.PipelineNode",
     type_: str,
     name: str,
     properties: Dict[str, str],
@@ -44,16 +44,14 @@ def add_context_to_node(
         properties: dictionary of strings as properties of the context
     """
     # Add a new context to the pipeline
-    context: "pipeline_pb2.ContextSpec" = (  # type: ignore[valid-type]
-        pipeline_node.contexts.contexts.add()  # type: ignore[attr-defined]
-    )
+    context: "pipeline_pb2.ContextSpec" = pipeline_node.contexts.contexts.add()
     # Adding the type of context
-    context.type.name = type_  # type: ignore[attr-defined]
+    context.type.name = type_
     # Setting the name of the context
-    context.name.field_value.string_value = name  # type: ignore[attr-defined]
+    context.name.field_value.string_value = name
     # Setting the properties of the context depending on attribute type
     for key, value in properties.items():
-        c_property = context.properties[key]  # type:ignore[attr-defined]
+        c_property = context.properties[key]
         c_property.field_value.string_value = value


@@ -94,7 +92,7 @@ def _inner_generator(


 def add_runtime_configuration_to_node(
-    pipeline_node: "pipeline_pb2.PipelineNode",  # type: ignore[valid-type]
+    pipeline_node: "pipeline_pb2.PipelineNode",
     runtime_config: RuntimeConfiguration,
 ) -> None:
     """
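NOTE: the collapsed assignment in `add_context_to_node` above relies on
standard protobuf semantics: `contexts.contexts.add()` appends a new
`ContextSpec` to a repeated field and returns it for in-place mutation, and
`context.properties[key]` auto-creates a map entry on first access. Sketch
with placeholder values (mirrors the pattern above, nothing new):

    context = pipeline_node.contexts.contexts.add()
    context.type.name = "example_type"
    context.name.field_value.string_value = "example_name"
    context.properties["example_key"].field_value.string_value = "example_value"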
diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py
index e273e45ebd4..d68e20cd9b8 100644
--- a/src/zenml/orchestrators/local/local_orchestrator.py
+++ b/src/zenml/orchestrators/local/local_orchestrator.py
@@ -93,7 +93,7 @@ def run_pipeline(
             tfx_pipeline.pipeline_info.pipeline_root
         )

-        pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)  # type: ignore[valid-type]
+        pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

         # Substitute the runtime parameter to be a concrete run_id
         runtime_parameter_utils.substitute_runtime_parameter(
@@ -115,8 +115,8 @@ def run_pipeline(

         # Run each component. Note that the pipeline.components list is in
         # topological order.
-        for node in pb2_pipeline.nodes:  # type: ignore[attr-defined]
-            pipeline_node: PipelineNode = node.pipeline_node  # type: ignore[valid-type]
+        for node in pb2_pipeline.nodes:
+            pipeline_node: PipelineNode = node.pipeline_node

             # fill out that context
             context_utils.add_context_to_node(
@@ -131,7 +131,7 @@ def run_pipeline(
                 pipeline_node, runtime_configuration
             )

-            node_id = pipeline_node.node_info.id  # type:ignore[attr-defined]
+            node_id = pipeline_node.node_info.id
             executor_spec = runner_utils.extract_executor_spec(
                 deployment_config, node_id
             )
@@ -139,8 +139,8 @@ def run_pipeline(
                 deployment_config, node_id
             )

-            p_info = pb2_pipeline.pipeline_info  # type:ignore[attr-defined]
-            r_spec = pb2_pipeline.runtime_spec  # type:ignore[attr-defined]
+            p_info = pb2_pipeline.pipeline_info
+            r_spec = pb2_pipeline.runtime_spec

             component_launcher = launcher.Launcher(
                 pipeline_node=pipeline_node,
diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py
index 7d3eb969889..9a4c26092f0 100644
--- a/src/zenml/orchestrators/utils.py
+++ b/src/zenml/orchestrators/utils.py
@@ -14,7 +14,7 @@

 import json
 import time
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, cast

 import tfx.orchestration.pipeline as tfx_pipeline
 from tfx.orchestration.portable import data_types, launcher
@@ -72,8 +72,8 @@ def get_cache_status(
         The caching status of a `tfx` step as a boolean value.
     """
     if execution_info is None:
-        logger.warn("No execution info found when checking for cache status.")
-        return
+        logger.warning("No execution info found when checking cache status.")
+        return False
     status = False
     repository = Repository()
@@ -91,7 +91,7 @@ def get_cache_status(
         pipeline_name = execution_info.pipeline_info.id
     else:
         raise KeyError(f"No pipeline info found for step `{step_name}`.")
-    pipeline_run_name = execution_info.pipeline_run_id
+    pipeline_run_name = cast(str, execution_info.pipeline_run_id)
     pipeline = metadata_store.get_pipeline(pipeline_name)
     if pipeline is None:
         logger.error(f"Pipeline {pipeline_name} not found in Metadata Store.")
@@ -116,7 +116,7 @@ def execute_step(
     step_name_param = (
         INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
     )
-    pipeline_step_name = tfx_launcher._pipeline_node.node_info.id  # type: ignore[attr-defined]
+    pipeline_step_name = tfx_launcher._pipeline_node.node_info.id
     start_time = time.time()
     logger.info(f"Step `{pipeline_step_name}` has started.")
     try:
diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py
index 39a57bbf160..0c697a5472f 100644
--- a/src/zenml/steps/utils.py
+++ b/src/zenml/steps/utils.py
@@ -436,7 +436,7 @@ def Do(
         # the pipeline runtime, such as the current step name and the current
         # pipeline run ID
         with StepEnvironment(
-            pipeline_name=self._context.pipeline_info.id,  # type: ignore[attr-defined]
+            pipeline_name=self._context.pipeline_info.id,
             pipeline_run_id=self._context.pipeline_run_id,
             step_name=getattr(self, PARAM_STEP_NAME),
         ):
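NOTE: two hunks in src/zenml/orchestrators/utils.py above are behavioral
fixes riding along with the typing cleanup: `logger.warn` is a deprecated
alias of `logger.warning`, and the bare `return` made `get_cache_status`
hand back `None` even though its docstring promises a boolean. A
standard-library illustration of the logging point only:

    import logging

    logger = logging.getLogger(__name__)
    logger.warning("No execution info found when checking cache status.")
    # logging.Logger.warn is a deprecated alias of .warning (since Python 3.3)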