diff --git a/databricks/sdk/mixins/compute.py b/databricks/sdk/mixins/compute.py index c840eb63..61733a6f 100644 --- a/databricks/sdk/mixins/compute.py +++ b/databricks/sdk/mixins/compute.py @@ -1,9 +1,15 @@ +import datetime +import logging import re +import time from dataclasses import dataclass from typing import Optional +from databricks.sdk.errors import OperationFailed from databricks.sdk.service import compute +_LOG = logging.getLogger('databricks.sdk') + @dataclass class SemVer: @@ -203,16 +209,28 @@ def select_node_type(self, return nt.node_type_id raise ValueError("cannot determine smallest node type") - def ensure_cluster_is_running(self, cluster_id: str): + def ensure_cluster_is_running(self, cluster_id: str) -> None: """Ensures that given cluster is running, regardless of the current state""" - state = compute.State - info = self.get(cluster_id) - if info.state == state.TERMINATED: - self.start(cluster_id).result() - elif info.state == state.TERMINATING: - self.wait_get_cluster_terminated(cluster_id) - self.start(cluster_id).result() - elif info.state in (state.PENDING, state.RESIZING, state.RESTARTING): - self.wait_get_cluster_running(cluster_id) - elif info.state in (state.ERROR, state.UNKNOWN): - raise RuntimeError(f'Cluster {info.cluster_name} is {info.state}: {info.state_message}') + timeout = datetime.timedelta(minutes=20) + deadline = time.time() + timeout.total_seconds() + while time.time() < deadline: + try: + state = compute.State + info = self.get(cluster_id) + if info.state == state.RUNNING: + return + elif info.state == state.TERMINATED: + self.start(cluster_id).result() + return + elif info.state == state.TERMINATING: + self.wait_get_cluster_terminated(cluster_id) + self.start(cluster_id).result() + return + elif info.state in (state.PENDING, state.RESIZING, state.RESTARTING): + self.wait_get_cluster_running(cluster_id) + return + elif info.state in (state.ERROR, state.UNKNOWN): + raise RuntimeError(f'Cluster {info.cluster_name} is {info.state}: {info.state_message}') + except OperationFailed as e: + _LOG.debug('Operation failed, retrying', exc_info=e) + raise TimeoutError(f'timed out after {timeout}')