From 2b4279f9eecc94713e33ee35fd43c883779f6e0e Mon Sep 17 00:00:00 2001
From: jankuehle
Date: Tue, 13 Aug 2024 06:55:59 -0700
Subject: [PATCH] Fix or ignore some pytype errors.

PiperOrigin-RevId: 662500667
---
 .../input_resolution/ops/latest_policy_model_op.py  |  8 ++++----
 tfx/dsl/input_resolution/ops/test_utils.py          | 12 ++++++------
 tfx/orchestration/metadata.py                       |  2 +-
 tfx/orchestration/portable/importer_node_handler.py |  2 +-
 tfx/orchestration/portable/partial_run_utils.py     |  2 +-
 tfx/orchestration/portable/resolver_node_handler.py |  2 +-
 6 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py
index ac061466fb..6726587654 100644
--- a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py
+++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py
@@ -79,13 +79,13 @@ def add_downstream_artifact(
     """Adds a downstream artifact to the ModelRelations."""
     artifact_type_name = downstream_artifact.type
     if _is_eval_blessed(artifact_type_name, downstream_artifact):
-      self.model_blessing_artifacts.append(downstream_artifact)
+      self.model_blessing_artifacts.append(downstream_artifact)  # pytype: disable=container-type-mismatch  # dont-delete-module-type
     elif _is_infra_blessed(artifact_type_name, downstream_artifact):
-      self.infra_blessing_artifacts.append(downstream_artifact)
+      self.infra_blessing_artifacts.append(downstream_artifact)  # pytype: disable=container-type-mismatch  # dont-delete-module-type
     elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
-      self.model_push_artifacts.append(downstream_artifact)
+      self.model_push_artifacts.append(downstream_artifact)  # pytype: disable=container-type-mismatch  # dont-delete-module-type

   def meets_policy(self, policy: Policy) -> bool:
     """Checks if ModelRelations contains artifacts that meet the Policy."""
@@ -486,7 +486,7 @@ def event_filter(event):
     ]
     # Set `max_num_hops` to 50, which should be enough for this use case.
     batch_downstream_artifacts_and_types_by_model_identifier = (
-        mlmd_resolver.get_downstream_artifacts_by_artifacts(
+        mlmd_resolver.get_downstream_artifacts_by_artifacts(  # pytype: disable=wrong-arg-types  # dont-delete-module-type
             batch_model_artifacts,
             max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
             filter_query=filter_query,
diff --git a/tfx/dsl/input_resolution/ops/test_utils.py b/tfx/dsl/input_resolution/ops/test_utils.py
index 1d4b0705b5..7dc9894649 100644
--- a/tfx/dsl/input_resolution/ops/test_utils.py
+++ b/tfx/dsl/input_resolution/ops/test_utils.py
@@ -256,7 +256,7 @@ def create_examples(
     )
     self.put_execution(
         'ExampleGen',
-        inputs={},
+        inputs={},  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={'examples': self.unwrap_tfx_artifacts(examples)},
         contexts=contexts,
         connection_config=connection_config,
@@ -275,7 +275,7 @@ def transform_examples(
     )
     self.put_execution(
         'Transform',
-        inputs=inputs,
+        inputs=inputs,  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={
             'transform_graph': self.unwrap_tfx_artifacts([transform_graph])
         },
@@ -298,7 +298,7 @@ def train_on_examples(
       inputs['transform_graph'] = self.unwrap_tfx_artifacts([transform_graph])
     self.put_execution(
         'TFTrainer',
-        inputs=inputs,
+        inputs=inputs,  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={'model': self.unwrap_tfx_artifacts([model])},
         contexts=contexts,
         connection_config=connection_config,
@@ -325,7 +325,7 @@ def evaluator_bless_model(

     self.put_execution(
         'Evaluator',
-        inputs=inputs,
+        inputs=inputs,  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={'blessing': self.unwrap_tfx_artifacts([model_blessing])},
         contexts=contexts,
         connection_config=connection_config,
@@ -353,7 +353,7 @@ def infra_validator_bless_model(

     self.put_execution(
         'InfraValidator',
-        inputs={'model': self.unwrap_tfx_artifacts([model])},
+        inputs={'model': self.unwrap_tfx_artifacts([model])},  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={'result': self.unwrap_tfx_artifacts([model_infra_blessing])},
         contexts=contexts,
         connection_config=connection_config,
@@ -375,7 +375,7 @@ def push_model(
     )
     self.put_execution(
         'ServomaticPusher',
-        inputs={'model_export': self.unwrap_tfx_artifacts([model])},
+        inputs={'model_export': self.unwrap_tfx_artifacts([model])},  # pytype: disable=wrong-arg-types  # dont-delete-module-type
         outputs={'model_push': self.unwrap_tfx_artifacts([model_push])},
         contexts=contexts,
         connection_config=connection_config,
diff --git a/tfx/orchestration/metadata.py b/tfx/orchestration/metadata.py
index 427ec96fbd..ab740646ba 100644
--- a/tfx/orchestration/metadata.py
+++ b/tfx/orchestration/metadata.py
@@ -267,7 +267,7 @@ def get_published_artifacts_by_type_within_context(
   @staticmethod
   def _get_legacy_producer_component_id(
       execution: metadata_store_pb2.Execution) -> str:
-    return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value
+    return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value  # pytype: disable=bad-return-type  # dont-delete-module-type

   def get_qualified_artifacts(
       self,
diff --git a/tfx/orchestration/portable/importer_node_handler.py b/tfx/orchestration/portable/importer_node_handler.py
index d3997e8a86..1448a3ea4d 100644
--- a/tfx/orchestration/portable/importer_node_handler.py
+++ b/tfx/orchestration/portable/importer_node_handler.py
@@ -47,7 +47,7 @@ def _extract_proto_map(
     extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
     return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}

-  def run(
+  def run(  # pytype: disable=signature-mismatch  # dont-delete-module-type
       self, mlmd_connection: metadata.Metadata,
       pipeline_node: pipeline_pb2.PipelineNode,
       pipeline_info: pipeline_pb2.PipelineInfo,
diff --git a/tfx/orchestration/portable/partial_run_utils.py b/tfx/orchestration/portable/partial_run_utils.py
index fe701e9a2c..1087aa0f50 100644
--- a/tfx/orchestration/portable/partial_run_utils.py
+++ b/tfx/orchestration/portable/partial_run_utils.py
@@ -639,7 +639,7 @@ def _get_base_pipeline_run_context(
       pipeline_run_contexts, key=lambda c: c.create_time_since_epoch
   )
   if not sorted_run_contexts:
-    return None
+    return None  # pytype: disable=bad-return-type  # dont-delete-module-type

   logging.info(
       'base_run_id not provided. Default to latest pipeline run: %s',
diff --git a/tfx/orchestration/portable/resolver_node_handler.py b/tfx/orchestration/portable/resolver_node_handler.py
index 221d6bb278..b833d466ea 100644
--- a/tfx/orchestration/portable/resolver_node_handler.py
+++ b/tfx/orchestration/portable/resolver_node_handler.py
@@ -42,7 +42,7 @@ def _extract_proto_map(
     extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
     return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}

-  def run(
+  def run(  # pytype: disable=signature-mismatch  # dont-delete-module-type
       self, mlmd_connection: metadata.Metadata,
       pipeline_node: pipeline_pb2.PipelineNode,
       pipeline_info: pipeline_pb2.PipelineInfo,
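
Note (not part of the patch): every change above suppresses a pytype diagnostic with a trailing line-level comment rather than changing the types themselves. A minimal, standalone sketch of that pattern follows; the function and values are hypothetical and only illustrate how a `# pytype: disable=<error-class>` comment is scoped to one error class on one line.

from typing import List


def first_component_id(ids: List[int]) -> str:
  """Hypothetical helper: returns an int where the annotation promises str."""
  # Without the trailing comment, pytype reports bad-return-type here (the
  # same error class suppressed in metadata.py and partial_run_utils.py).
  # The comment silences only that error class, and only on this line.
  return ids[0]  # pytype: disable=bad-return-type


print(first_component_id([42]))  # Runs fine at runtime; prints 42.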