From e7697ec13d26725aeb220cf815194dbb194cab61 Mon Sep 17 00:00:00 2001
From: Alex Ott
Date: Sun, 24 Jul 2022 14:52:12 +0200
Subject: [PATCH] Fix mypy errors by using cached_property instead of explicit
 caching

---
 .../databricks/operators/databricks.py       | 64 +++++++++----------
 .../databricks/operators/databricks_repos.py | 29 +++++----
 2 files changed, 47 insertions(+), 46 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index d539e3d2eb9e5..1ecf23639decd 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -22,6 +22,7 @@
 from logging import Logger
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
 from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
@@ -382,28 +383,27 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-        self.hook = None
-
-    def _get_hook(self, caller="DatabricksSubmitRunOperator") -> DatabricksHook:
-        if not self.hook:
-            self.hook = DatabricksHook(
-                self.databricks_conn_id,
-                retry_limit=self.databricks_retry_limit,
-                retry_delay=self.databricks_retry_delay,
-                retry_args=self.databricks_retry_args,
-                caller=caller,
-            )
-        return self.hook
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSubmitRunOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
-        self.run_id = hook.submit_run(self.json)
-        _handle_databricks_operator_execution(self, hook, self.log, context)
+        self.run_id = self._hook.submit_run(self.json)
+        _handle_databricks_operator_execution(self, self._hook, self.log, context)
 
     def on_kill(self):
         if self.run_id:
-            hook = self._get_hook()
-            hook.cancel_run(self.run_id)
+            self._hook.cancel_run(self.run_id)
             self.log.info(
                 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
             )
@@ -638,21 +638,22 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-        self.hook = None
-
-    def _get_hook(self, caller="DatabricksRunNowOperator") -> DatabricksHook:
-        if not self.hook:
-            self.hook = DatabricksHook(
-                self.databricks_conn_id,
-                retry_limit=self.databricks_retry_limit,
-                retry_delay=self.databricks_retry_delay,
-                retry_args=self.databricks_retry_args,
-                caller=caller,
-            )
-        return self.hook
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksRunNowOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
+        hook = self._hook
         if 'job_name' in self.json:
             job_id = hook.find_job_id_by_name(self.json['job_name'])
             if job_id is None:
@@ -664,8 +665,7 @@ def execute(self, context: 'Context'):
 
     def on_kill(self):
         if self.run_id:
-            hook = self._get_hook()
-            hook.cancel_run(self.run_id)
+            self._hook.cancel_run(self.run_id)
             self.log.info(
                 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
             )
diff --git a/airflow/providers/databricks/operators/databricks_repos.py b/airflow/providers/databricks/operators/databricks_repos.py
index 89d5f1d0dfd44..32100bcb34250 100644
--- a/airflow/providers/databricks/operators/databricks_repos.py
+++ b/airflow/providers/databricks/operators/databricks_repos.py
@@ -21,6 +21,7 @@
 from typing import TYPE_CHECKING, Optional, Sequence
 from urllib.parse import urlparse
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -116,7 +117,8 @@ def __detect_repo_provider__(url):
             pass
         return provider
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -141,22 +143,21 @@ def execute(self, context: 'Context'):
                 f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
             )
         payload["path"] = self.repo_path
-        hook = self._get_hook()
         existing_repo_id = None
         if self.repo_path is not None:
-            existing_repo_id = hook.get_repo_by_path(self.repo_path)
+            existing_repo_id = self._hook.get_repo_by_path(self.repo_path)
             if existing_repo_id is not None and not self.ignore_existing_repo:
                 raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
         if existing_repo_id is None:
-            result = hook.create_repo(payload)
+            result = self._hook.create_repo(payload)
             repo_id = result["id"]
         else:
             repo_id = existing_repo_id
         # update repo if necessary
         if self.branch is not None:
-            hook.update_repo(str(repo_id), {'branch': str(self.branch)})
+            self._hook.update_repo(str(repo_id), {'branch': str(self.branch)})
         elif self.tag is not None:
-            hook.update_repo(str(repo_id), {'tag': str(self.tag)})
+            self._hook.update_repo(str(repo_id), {'tag': str(self.tag)})
 
         return repo_id
 
@@ -214,7 +215,8 @@ def __init__(
         self.branch = branch
         self.tag = tag
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -223,9 +225,8 @@ def __init__(
         )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
         if self.repo_path is not None:
-            self.repo_id = hook.get_repo_by_path(self.repo_path)
+            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
         if self.repo_id is None:
             raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
         if self.branch is not None:
@@ -233,7 +234,7 @@ def execute(self, context: 'Context'):
         else:
             payload = {'tag': str(self.tag)}
 
-        result = hook.update_repo(str(self.repo_id), payload)
+        result = self._hook.update_repo(str(self.repo_id), payload)
 
         return result['head_commit_id']
 
@@ -280,7 +281,8 @@ def __init__(
         self.repo_path = repo_path
         self.repo_id = repo_id
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -289,10 +291,9 @@ def __init__(
         )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
         if self.repo_path is not None:
-            self.repo_id = hook.get_repo_by_path(self.repo_path)
+            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
         if self.repo_id is None:
             raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
 
-        hook.delete_repo(str(self.repo_id))
+        self._hook.delete_repo(str(self.repo_id))
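
Editor's note: the pattern this patch adopts is worth spelling out. The old code
initialized `self.hook = None` and filled it in lazily, so mypy inferred the
attribute as `None` (or `Optional[DatabricksHook]`) and flagged both the later
assignment and every call site. `cached_property` computes the hook on first
access, stores it on the instance, and lets mypy see a plain non-Optional
attribute. The sketch below illustrates the idea with hypothetical names
(`Client` and `Operator` are stand-ins, not the provider classes); it uses the
standard library's `functools.cached_property`, which `airflow.compat.functools`
appears to simply re-export (with a backport for older Python versions).

    # Minimal sketch of the cached_property pattern -- illustrative names only.
    from functools import cached_property


    class Client:
        """Stand-in for DatabricksHook; assume it is costly to construct."""

        def __init__(self, conn_id: str) -> None:
            self.conn_id = conn_id


    class Operator:
        def __init__(self, conn_id: str) -> None:
            self.conn_id = conn_id
            # The old approach, roughly:
            #     self.client = None                 # mypy: attribute is None
            #     ...
            #     if not self.client:
            #         self.client = Client(...)      # mypy: incompatible types
            # which also forced Optional handling at every call site.

        @cached_property
        def _client(self) -> Client:
            # Built on first access, then cached in the instance __dict__, so
            # later accesses return the same object with no hand-written flag.
            return Client(self.conn_id)


    op = Operator("my_conn")
    assert op._client is op._client  # one cached instance per Operator

One practical consequence: because the value is cached per instance, `on_kill`
reuses the hook that `execute` already built in the same process, preserving the
old one-hook-per-operator behaviour.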