Fix mypy errors by using cached_property instead of explicit caching
alexott authored and potiuk committed Jul 25, 2022
1 parent 83d2881 commit e7697ec
Showing 2 changed files with 47 additions and 46 deletions.
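The pattern being swapped in: the operators previously set `self.hook = None` in `__init__` and lazily replaced it inside `_get_hook()`. Because mypy infers an unannotated attribute's type from its first assignment, `self.hook` is typed as `None`, so both the later reassignment to a `DatabricksHook` and the non-Optional return type fail to type-check. `cached_property` avoids this: the value is computed once on first access, memoised in the instance `__dict__` under the same name, and carries a single precise type. The sketch below shows the before/after shapes under some assumptions: `Hook`, `Before`, and `After` are stand-in names for illustration only, the error comments show typical mypy output rather than the exact messages from this build, and the stdlib import assumes Python 3.8+ (the provider itself imports the decorator through the `airflow.compat.functools` shim).

from functools import cached_property  # the diff uses airflow.compat.functools instead


class Hook:
    """Stand-in for DatabricksHook; illustration only."""

    def cancel_run(self, run_id: int) -> None:
        print(f"cancelled {run_id}")


class Before:
    """Explicit caching, as removed by this commit."""

    def __init__(self) -> None:
        self.hook = None  # mypy infers the attribute type as None

    def _get_hook(self) -> Hook:
        if not self.hook:
            # typical error: Incompatible types in assignment
            # (expression has type "Hook", variable has type "None")
            self.hook = Hook()
        return self.hook  # typical error: Incompatible return value type


class After:
    """cached_property, as added by this commit: computed once per
    instance, cached under the same attribute name, typed as a plain Hook."""

    @cached_property
    def _hook(self) -> Hook:
        return Hook()


op = After()
assert op._hook is op._hook  # second access reuses the cached instance

Runtime behaviour matches what the hand-rolled cache was emulating (for example, `on_kill` reuses the hook that `execute` already built), while the type checker sees a plain `DatabricksHook`. In databricks.py the property additionally routes through `_get_hook(caller=...)` so each operator can pass its own caller tag to the hook.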
64 changes: 32 additions & 32 deletions airflow/providers/databricks/operators/databricks.py
@@ -22,6 +22,7 @@
 from logging import Logger
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink, XCom
 from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
@@ -382,28 +383,27 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-        self.hook = None
-
-    def _get_hook(self, caller="DatabricksSubmitRunOperator") -> DatabricksHook:
-        if not self.hook:
-            self.hook = DatabricksHook(
-                self.databricks_conn_id,
-                retry_limit=self.databricks_retry_limit,
-                retry_delay=self.databricks_retry_delay,
-                retry_args=self.databricks_retry_args,
-                caller=caller,
-            )
-        return self.hook
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSubmitRunOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
-        self.run_id = hook.submit_run(self.json)
-        _handle_databricks_operator_execution(self, hook, self.log, context)
+        self.run_id = self._hook.submit_run(self.json)
+        _handle_databricks_operator_execution(self, self._hook, self.log, context)
 
     def on_kill(self):
         if self.run_id:
-            hook = self._get_hook()
-            hook.cancel_run(self.run_id)
+            self._hook.cancel_run(self.run_id)
             self.log.info(
                 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
             )
@@ -638,21 +638,22 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-        self.hook = None
-
-    def _get_hook(self, caller="DatabricksRunNowOperator") -> DatabricksHook:
-        if not self.hook:
-            self.hook = DatabricksHook(
-                self.databricks_conn_id,
-                retry_limit=self.databricks_retry_limit,
-                retry_delay=self.databricks_retry_delay,
-                retry_args=self.databricks_retry_args,
-                caller=caller,
-            )
-        return self.hook
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksRunNowOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
+        hook = self._hook
         if 'job_name' in self.json:
             job_id = hook.find_job_id_by_name(self.json['job_name'])
             if job_id is None:
@@ -664,8 +665,7 @@ def execute(self, context: 'Context'):
 
     def on_kill(self):
         if self.run_id:
-            hook = self._get_hook()
-            hook.cancel_run(self.run_id)
+            self._hook.cancel_run(self.run_id)
             self.log.info(
                 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
             )
29 changes: 15 additions & 14 deletions airflow/providers/databricks/operators/databricks_repos.py
@@ -21,6 +21,7 @@
 from typing import TYPE_CHECKING, Optional, Sequence
 from urllib.parse import urlparse
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -116,7 +117,8 @@ def __detect_repo_provider__(url):
             pass
         return provider
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -141,22 +143,21 @@ def execute(self, context: 'Context'):
                 f"repo_path should have form of /Repos/{{folder}}/{{repo-name}}, got '{self.repo_path}'"
             )
         payload["path"] = self.repo_path
-        hook = self._get_hook()
         existing_repo_id = None
         if self.repo_path is not None:
-            existing_repo_id = hook.get_repo_by_path(self.repo_path)
+            existing_repo_id = self._hook.get_repo_by_path(self.repo_path)
             if existing_repo_id is not None and not self.ignore_existing_repo:
                 raise AirflowException(f"Repo with path '{self.repo_path}' already exists")
         if existing_repo_id is None:
-            result = hook.create_repo(payload)
+            result = self._hook.create_repo(payload)
             repo_id = result["id"]
         else:
             repo_id = existing_repo_id
         # update repo if necessary
         if self.branch is not None:
-            hook.update_repo(str(repo_id), {'branch': str(self.branch)})
+            self._hook.update_repo(str(repo_id), {'branch': str(self.branch)})
         elif self.tag is not None:
-            hook.update_repo(str(repo_id), {'tag': str(self.tag)})
+            self._hook.update_repo(str(repo_id), {'tag': str(self.tag)})
 
         return repo_id
 
@@ -214,7 +215,8 @@ def __init__(
         self.branch = branch
         self.tag = tag
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -223,17 +225,16 @@ def _get_hook(self) -> DatabricksHook:
         )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
         if self.repo_path is not None:
-            self.repo_id = hook.get_repo_by_path(self.repo_path)
+            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
             if self.repo_id is None:
                 raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
         if self.branch is not None:
            payload = {'branch': str(self.branch)}
         else:
             payload = {'tag': str(self.tag)}
 
-        result = hook.update_repo(str(self.repo_id), payload)
+        result = self._hook.update_repo(str(self.repo_id), payload)
         return result['head_commit_id']
 
@@ -280,7 +281,8 @@ def __init__(
         self.repo_path = repo_path
         self.repo_id = repo_id
 
-    def _get_hook(self) -> DatabricksHook:
+    @cached_property
+    def _hook(self) -> DatabricksHook:
         return DatabricksHook(
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
@@ -289,10 +291,9 @@ def _get_hook(self) -> DatabricksHook:
         )
 
     def execute(self, context: 'Context'):
-        hook = self._get_hook()
         if self.repo_path is not None:
-            self.repo_id = hook.get_repo_by_path(self.repo_path)
+            self.repo_id = self._hook.get_repo_by_path(self.repo_path)
             if self.repo_id is None:
                 raise AirflowException(f"Can't find Repo ID for path '{self.repo_path}'")
 
-        hook.delete_repo(str(self.repo_id))
+        self._hook.delete_repo(str(self.repo_id))
