Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #7785: fail-fast behavior #8066

Merged
merged 10 commits into from
Jul 11, 2023
Merged
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230710-172547.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix fail-fast behavior (including retry)
time: 2023-07-10T17:25:47.912129-05:00
custom:
Author: aranke
Issue: "7785"
16 changes: 16 additions & 0 deletions core/dbt/contracts/results.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading

from dbt.contracts.graph.unparsed import FreshnessThreshold
from dbt.contracts.graph.nodes import SourceDefinition, ResultNode
from dbt.contracts.util import (
Expand Down Expand Up @@ -161,6 +163,20 @@ class RunResult(NodeResult):
def skipped(self):
return self.status == RunStatus.Skipped

@classmethod
def from_node(cls, node: ResultNode, status: RunStatus, message: Optional[str]):
thread_id = threading.current_thread().name
return RunResult(
status=status,
thread_id=thread_id,
execution_time=0,
timing=[],
message=message,
node=node,
adapter_response={},
failures=None,
)


@dataclass
class ExecutionResult(dbtClassMixin):
Expand Down
15 changes: 1 addition & 14 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,19 +296,6 @@ def from_run_result(self, result, start_time, timing_info):
failures=result.failures,
)

def skip_result(self, node, message):
thread_id = threading.current_thread().name
return RunResult(
status=RunStatus.Skipped,
thread_id=thread_id,
execution_time=0,
timing=[],
message=message,
node=node,
adapter_response={},
failures=None,
)

def compile_and_execute(self, manifest, ctx):
result = None
with self.adapter.connection_for(self.node) if get_flags().INTROSPECT else nullcontext():
Expand Down Expand Up @@ -483,7 +470,7 @@ def on_skip(self):
)
)

node_result = self.skip_result(self.node, error_message)
node_result = RunResult.from_node(self.node, RunStatus.Skipped, error_message)
return node_result

def do_skip(self, cause=None):
Expand Down
75 changes: 42 additions & 33 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import os
import time
from pathlib import Path
from abc import abstractmethod
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from pathlib import Path
from typing import Optional, Dict, List, Set, Tuple, Iterable, AbstractSet

from .printer import (
print_run_result_error,
print_run_end_messages,
)

from dbt.task.base import ConfiguredTask
import dbt.exceptions
import dbt.tracking
import dbt.utils
from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
from dbt.contracts.graph.manifest import WritableManifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import (
NodeStatus,
RunExecutionResult,
RunningStatus,
RunResult,
RunStatus,
)
from dbt.contracts.state import PreviousState
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
Formatting,
Expand All @@ -36,25 +35,29 @@
EndRunResult,
NothingToDo,
)
from dbt.events.contextvars import log_contextvars, task_contextvars
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import NodeStatus, RunExecutionResult, RunningStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import (
DbtInternalError,
NotImplementedError,
DbtRuntimeError,
FailFastError,
)

from dbt.flags import get_flags
from dbt.graph import GraphQueue, NodeSelector, SelectionSpec, parse_difference
from dbt.logger import (
DbtProcessState,
TextOnly,
UniqueID,
TimestampNamed,
DbtModelState,
ModelMetadata,
NodeCount,
)
from dbt.parser.manifest import write_manifest
import dbt.tracking

import dbt.exceptions
from dbt.flags import get_flags
import dbt.utils
from dbt.contracts.graph.manifest import WritableManifest
from dbt.task.base import ConfiguredTask
from .printer import (
print_run_result_error,
print_run_end_messages,
)

RESULT_FILE_NAME = "run_results.json"
RUNNING_STATE = DbtProcessState("running")
Expand Down Expand Up @@ -360,21 +363,27 @@ def execute_nodes(self):
pool = ThreadPool(num_threads)
try:
self.run_queue(pool)

except FailFastError as failure:
self._cancel_connections(pool)

executed_node_ids = [r.node.unique_id for r in self.node_results]

for r in self._flattened_nodes:
if r.unique_id not in executed_node_ids:
self.node_results.append(
RunResult.from_node(r, RunStatus.Skipped, "Skipping due to fail_fast")
)

print_run_result_error(failure.result)
raise

except KeyboardInterrupt:
self._cancel_connections(pool)
print_run_end_messages(self.node_results, keyboard_interrupt=True)
raise

pool.close()
pool.join()

return self.node_results
finally:
pool.close()
pool.join()
return self.node_results
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why return is also moved into the finally block?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return is moved to the finally block, otherwise only the current node is returned, but information about all the other node results is lost.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe finally blocks are executed whether an exception is raised or not


def _mark_dependent_errors(self, node_id, result, cause):
if self.graph is None:
Expand Down
6 changes: 4 additions & 2 deletions tests/functional/dependencies/test_local_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dbt.semver
import dbt.config
import dbt.exceptions
from dbt.contracts.results import RunStatus

from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture

Expand Down Expand Up @@ -207,8 +208,9 @@ def models(self):

def test_missing_dependency(self, project):
# dbt should raise a runtime exception
with pytest.raises(dbt.exceptions.DbtRuntimeError):
run_dbt(["compile"])
res = run_dbt(["compile"], expect_pass=False)
assert len(res) == 1
assert res[0].status == RunStatus.Error


class TestSimpleDependencyWithSchema(BaseDependencyTest):
Expand Down
17 changes: 10 additions & 7 deletions tests/functional/fail_fast/test_fail_fast_run.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest
import json
from pathlib import Path

import pytest

from dbt.contracts.results import RunResult
from dbt.tests.util import run_dbt


models__one_sql = """
select 1
"""
Expand All @@ -30,8 +28,11 @@ def test_fail_fast_run(
models, # noqa: F811
):
res = run_dbt(["run", "--fail-fast", "--threads", "1"], expect_pass=False)
# a RunResult contains only one node so we can be sure only one model was run
assert type(res) == RunResult
assert {r.node.unique_id: r.status for r in res.results} == {
"model.test.one": "success",
"model.test.two": "error",
}

run_results_file = Path(project.project_root) / "target/run_results.json"
assert run_results_file.is_file()
with run_results_file.open() as run_results_str:
Expand All @@ -57,5 +58,7 @@ def test_fail_fast_run_user_config(
models, # noqa: F811
):
res = run_dbt(["run", "--threads", "1"], expect_pass=False)
# a RunResult contains only one node so we can be sure only one model was run
assert type(res) == RunResult
assert {r.node.unique_id: r.status for r in res.results} == {
"model.test.one": "success",
"model.test.two": "error",
}
70 changes: 54 additions & 16 deletions tests/functional/retry/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,6 @@ def test_run_operation(self, project):
results = run_dbt(["retry"], expect_pass=False)
assert {n.unique_id: n.status for n in results.results} == expected_statuses

def test_fail_fast(self, project):
result = run_dbt(["--warn-error", "build", "--fail-fast"], expect_pass=False)

assert result.status == RunStatus.Error
assert result.node.name == "sample_model"

results = run_dbt(["retry"], expect_pass=False)

assert len(results.results) == 1
assert results.results[0].status == RunStatus.Error
assert results.results[0].node.name == "sample_model"

result = run_dbt(["retry", "--fail-fast"], expect_pass=False)
assert result.status == RunStatus.Error
assert result.node.name == "sample_model"

def test_removed_file(self, project):
run_dbt(["build"], expect_pass=False)

Expand All @@ -180,3 +164,57 @@ def test_removed_file_leaf_node(self, project):
rm_file("models", "third_model.sql")
with pytest.raises(ValueError, match="Couldn't find model 'model.test.third_model'"):
run_dbt(["retry"], expect_pass=False)


class TestFailFast:
@pytest.fixture(scope="class")
def models(self):
return {
"sample_model.sql": models__sample_model,
"second_model.sql": models__second_model,
"union_model.sql": models__union_model,
"final_model.sql": "select * from {{ ref('union_model') }};",
}

def test_fail_fast(self, project):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love the functional test here!

results = run_dbt(["--fail-fast", "build"], expect_pass=False)
assert {r.node.unique_id: r.status for r in results.results} == {
"model.test.sample_model": RunStatus.Error,
"model.test.second_model": RunStatus.Success,
"model.test.union_model": RunStatus.Skipped,
"model.test.final_model": RunStatus.Skipped,
}

# Check that retry inherits fail-fast from upstream command (build)
results = run_dbt(["retry"], expect_pass=False)
assert {r.node.unique_id: r.status for r in results.results} == {
"model.test.sample_model": RunStatus.Error,
"model.test.union_model": RunStatus.Skipped,
"model.test.final_model": RunStatus.Skipped,
}

fixed_sql = "select 1 as id, 1 as foo"
write_file(fixed_sql, "models", "sample_model.sql")

results = run_dbt(["retry"], expect_pass=False)
assert {r.node.unique_id: r.status for r in results.results} == {
"model.test.sample_model": RunStatus.Success,
"model.test.union_model": RunStatus.Success,
"model.test.final_model": RunStatus.Error,
}

results = run_dbt(["retry"], expect_pass=False)
assert {r.node.unique_id: r.status for r in results.results} == {
"model.test.final_model": RunStatus.Error,
}

fixed_sql = "select * from {{ ref('union_model') }}"
write_file(fixed_sql, "models", "final_model.sql")

results = run_dbt(["retry"])
assert {r.node.unique_id: r.status for r in results.results} == {
"model.test.final_model": RunStatus.Success,
}

results = run_dbt(["retry"])
assert {r.node.unique_id: r.status for r in results.results} == {}