Skip to content

Commit

Permalink
Add support for the "at" operator
Browse files Browse the repository at this point in the history
Documentation: https://docs.getdbt.com/reference/node-selection/graph-operators#the-at-operator

Signed-off-by: Ben Girard <ben.girard@stillfront.com>
  • Loading branch information
benjy44 committed Dec 9, 2024
1 parent 3ff70d8 commit 8fc9a90
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 7 deletions.
48 changes: 41 additions & 7 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."
PLUS_SELECTOR = "+"
GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"
AT_SELECTOR = "@"
GRAPH_SELECTOR_REGEX = r"^(@|[0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"

logger = get_logger(__name__)

Expand All @@ -35,6 +36,7 @@ class GraphSelector:
+model_d+
2+model_e
model_f+3
@model_g
+/path/to/model_g+
path:/path/to/model_h+
+tag:nightly
Expand All @@ -46,6 +48,7 @@ class GraphSelector:
node_name: str
precursors: str | None
descendants: str | None
at_operator: bool = False

@property
def precursors_depth(self) -> int:
Expand All @@ -56,6 +59,8 @@ def precursors_depth(self) -> int:
0: if it shouldn't return any precursors
>0: upperbound number of parent generations
"""
if self.at_operator:
return -1
if not self.precursors:
return 0
if self.precursors == "+":
Expand Down Expand Up @@ -90,7 +95,13 @@ def parse(text: str) -> GraphSelector | None:
precursors, node_name, descendants = regex_match.groups()
if "/" in node_name and not node_name.startswith(PATH_SELECTOR):
node_name = f"{PATH_SELECTOR}{node_name}"
return GraphSelector(node_name, precursors, descendants)

at_operator = precursors == AT_SELECTOR
if at_operator:
precursors = None
descendants = "+" # @ implies all descendants

return GraphSelector(node_name, precursors, descendants, at_operator)
return None

def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
Expand All @@ -101,7 +112,7 @@ def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, select
:param root_id: Unique identifier of self.node_name
:param selected_nodes: Set where precursor nodes will be added to.
"""
if self.precursors:
if self.precursors or self.at_operator:
depth = self.precursors_depth
previous_generation = {root_id}
processed_nodes = set()
Expand Down Expand Up @@ -203,16 +214,39 @@ def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
root_id = node_by_name[self.node_name]
root_nodes.add(root_id)
else:
logger.warn(f"Selector {self.node_name} not found.")
logger.warning(f"Selector {self.node_name} not found.")
return selected_nodes

selected_nodes.update(root_nodes)

for root_id in root_nodes:
self.select_node_precursors(nodes, root_id, selected_nodes)
self.select_node_descendants(nodes, root_id, selected_nodes)
self._select_nodes(nodes, root_nodes, selected_nodes)

return selected_nodes

def _select_nodes(self, nodes: dict[str, DbtNode], root_nodes: set[str], selected_nodes: set[str]) -> None:
    """
    Expand the given root nodes into ``selected_nodes`` following this graph selector.

    Without the ``@`` operator, each root contributes its precursors and its
    descendants. With the ``@`` operator, the descendants of every root are
    gathered first, and the precursors are then collected for the roots and
    for each of those descendants.

    :param nodes: dbt project nodes
    :param root_nodes: Set of root node ids
    :param selected_nodes: Set where selected nodes will be added to (mutated in place).
    """
    if not self.at_operator:
        # Regular graph selection: walk up and down from every root.
        for root_id in root_nodes:
            self.select_node_precursors(nodes, root_id, selected_nodes)
            self.select_node_descendants(nodes, root_id, selected_nodes)
        return

    # "@" selection, step 1: gather everything downstream of the roots.
    downstream: set[str] = set()
    for root_id in root_nodes:
        self.select_node_descendants(nodes, root_id, downstream)
    selected_nodes.update(downstream)

    # "@" selection, step 2: pull in the precursors of the roots and of every
    # downstream node collected above.
    for node_id in root_nodes | downstream:
        self.select_node_precursors(nodes, node_id, selected_nodes)


class SelectorConfig:
"""
Expand Down
12 changes: 12 additions & 0 deletions docs/configuration/selecting-excluding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The ``select`` and ``exclude`` parameters are lists, with values like the follow
- ``config.materialized:table``: include/exclude models with the config ``materialized: table``
- ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory
- ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs <https://docs.getdbt.com/reference/node-selection/graph-operators>`_)
- ``@node_name`` (@ operator): include/exclude the node with name ``node_name``, all its descendants, and all ancestors of the node and of those descendants. This is useful in CI environments where you want to build a model and all its descendants, but you need the ancestors of those descendants to exist first.
- ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs <https://docs.getdbt.com/reference/node-selection/set-operators>`_)
- ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag``

Expand Down Expand Up @@ -91,6 +92,17 @@ Examples:
)
)

.. code-block:: python

    from cosmos import DbtDag, RenderConfig

    jaffle_shop = DbtDag(
        render_config=RenderConfig(
            select=["@my_model"],  # selects my_model, all its descendants,
            # and all ancestors needed to build those descendants
        )
    )

Using ``selector``
--------------------------------
.. note::
Expand Down
80 changes: 80 additions & 0 deletions tests/dbt/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,83 @@ def test_should_include_node_without_depends_on(selector_config):
def test_select_using_graph_operators(select_statement, expected):
    """Selecting with ``select_statement`` must yield exactly the ``expected`` sorted node ids."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=select_statement)
    chosen_ids = sorted(result.keys())
    assert chosen_ids == expected


def test_select_nodes_by_at_operator():
    """``@parent`` selects the node, its descendants, and the ancestors of all of them."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@parent"])
    chosen_ids = sorted(result.keys())
    assert chosen_ids == [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]


def test_select_nodes_by_at_operator_leaf_node():
    """``@child`` (a leaf node, no descendants) selects the node together with its ancestors."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@child"])
    chosen_ids = sorted(result.keys())
    assert chosen_ids == [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
    ]


def test_select_nodes_by_at_operator_root_node():
    """``@grandparent`` (a root node, no ancestors of its own) still pulls in ancestors of its descendants."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@grandparent"])
    chosen_ids = sorted(result.keys())
    assert chosen_ids == [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]


def test_select_nodes_by_at_operator_union():
    """The union of ``@child`` with the ``tag:has_child`` selector yields the expected node ids."""
    select_statement = ["@child", "tag:has_child"]
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=select_statement)
    chosen_ids = sorted(result.keys())
    assert chosen_ids == [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
    ]


def test_select_nodes_by_at_operator_with_path():
    """The ``@`` operator also works when the selector is a path (``@gen2/models``)."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@gen2/models"])
    chosen_ids = sorted(result.keys())
    assert chosen_ids == [
        "model.dbt-proj.another_grandparent_node",
        "model.dbt-proj.child",
        "model.dbt-proj.grandparent",
        "model.dbt-proj.parent",
        "model.dbt-proj.sibling1",
        "model.dbt-proj.sibling2",
    ]


def test_select_nodes_by_at_operator_nonexistent_node():
    """``@`` applied to a node name that does not exist selects nothing."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["@nonexistent"])
    chosen_ids = sorted(result.keys())
    assert chosen_ids == []


def test_exclude_with_at_operator():
    """Excluding ``@parent`` leaves only the ``orphaned`` model selected."""
    result = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["@parent"])
    remaining_ids = sorted(result.keys())
    assert remaining_ids == ["model.dbt-proj.orphaned"]

0 comments on commit 8fc9a90

Please sign in to comment.