-
Notifications
You must be signed in to change notification settings - Fork 180
/
Copy pathselector.py
551 lines (457 loc) · 20.5 KB
/
selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
from __future__ import annotations
import copy
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from cosmos.constants import DbtResourceType
from cosmos.exceptions import CosmosValueError
from cosmos.log import get_logger
if TYPE_CHECKING:
from cosmos.dbt.graph import DbtNode
# Config keys accepted by the ``config.<key>:<value>`` selector method.
SUPPORTED_CONFIG = ["materialized", "schema", "tags"]
# Selector method prefixes and graph operators recognised in select/exclude statements.
PATH_SELECTOR = "path:"
TAG_SELECTOR = "tag:"
CONFIG_SELECTOR = "config."
PLUS_SELECTOR = "+"
AT_SELECTOR = "@"
# Captures (precursors, node_name, descendants) of a graph selector such as "2+model+3" or "@model":
# precursors are "@" or an optional depth followed by "+", descendants are "+" followed by an
# optional depth. The trailing "|" adds an empty alternative, so ``re.search`` always returns a match
# object; when the selector part does not match (e.g. "" or "a+b") every group is None.
GRAPH_SELECTOR_REGEX = r"^(@|[0-9]*\+)?([^\+]+)(\+[0-9]*)?$|"
logger = get_logger(__name__)
@dataclass
class GraphSelector:
    """
    Implements dbt graph operator selectors:
        model_a
        +model_b
        model_c+
        +model_d+
        2+model_e
        model_f+3
        @model_g
        +/path/to/model_g+
        path:/path/to/model_h+
        +tag:nightly
        +config.materialized:view

    https://docs.getdbt.com/reference/node-selection/graph-operators
    """

    # What the graph operators apply to: a node name, or a "path:", "tag:" or "config." selector.
    node_name: str
    # Leading operator as written ("+" or e.g. "2+"); None when absent or when "@" was used.
    precursors: str | None
    # Trailing operator as written ("+" or e.g. "+3"); None when absent.
    descendants: str | None
    # True when the selector started with "@".
    at_operator: bool = False

    @property
    def precursors_depth(self) -> int:
        """
        Calculates the depth/degrees/generations of precursors (parents).

        Return:
            -1: if it should return all the generations of precursors
            0: if it shouldn't return any precursors
            >0: upperbound number of parent generations
        """
        if self.at_operator:
            return -1
        if not self.precursors:
            return 0
        if self.precursors == "+":
            return -1
        # "N+" -> N
        return int(self.precursors[:-1])

    @property
    def descendants_depth(self) -> int:
        """
        Calculates the depth/degrees/generations of descendants (children).

        Return:
            -1: if it should return all the generations of children
            0: if it shouldn't return any children
            >0: upperbound of children generations
        """
        if not self.descendants:
            return 0
        if self.descendants == "+":
            return -1
        # "+N" -> N
        return int(self.descendants[1:])

    @staticmethod
    def parse(text: str) -> GraphSelector | None:
        """
        Parse a string and identify if there are graph selectors, including the desired node name, descendants and
        precursors.

        :param text: a single selector, e.g. "2+model+", "@model" or "+tag:nightly"
        :return: a GraphSelector instance if the pattern matches, otherwise None
        """
        regex_match = re.search(GRAPH_SELECTOR_REGEX, text)
        if not regex_match:
            return None
        precursors, node_name, descendants = regex_match.groups()
        # GRAPH_SELECTOR_REGEX ends with an empty alternative, so re.search always returns a match;
        # text that is not a graph selector (e.g. "" or "a+b") leaves every group as None.
        if node_name is None:
            return None
        if "/" in node_name and not node_name.startswith(PATH_SELECTOR):
            node_name = f"{PATH_SELECTOR}{node_name}"
        at_operator = precursors == AT_SELECTOR
        if at_operator:
            precursors = None
            descendants = "+"  # @ implies all descendants
        return GraphSelector(node_name, precursors, descendants, at_operator)

    @staticmethod
    def _index_children(nodes: dict[str, DbtNode]) -> dict[str, set[str]]:
        """Map each node id to the ids of the nodes that list it in their ``depends_on``."""
        children_by_node: dict[str, set[str]] = defaultdict(set)
        for node_id, node in nodes.items():
            for parent_id in node.depends_on:
                children_by_node[parent_id].add(node_id)
        return children_by_node

    def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None:
        """
        Parse original nodes and add the precursor nodes related to this config to the selected_nodes set.

        Ids listed in ``depends_on`` but absent from ``nodes`` (e.g. pruned by an earlier selection)
        are still added to ``selected_nodes``, but their own parents cannot be followed.

        :param nodes: Original dbt nodes list
        :param root_id: Unique identifier of self.node_name
        :param selected_nodes: Set where precursor nodes will be added to.
        """
        if not (self.precursors or self.at_operator):
            return
        depth = self.precursors_depth
        previous_generation = {root_id}
        processed_nodes: set[str] = set()
        # A negative depth never reaches 0, so the walk then stops only when no new parents appear.
        while depth and previous_generation:
            new_generation: set[str] = set()
            for node_id in previous_generation - processed_nodes:
                processed_nodes.add(node_id)
                node = nodes.get(node_id)
                if node is not None:
                    new_generation.update(node.depends_on)
            selected_nodes.update(new_generation)
            previous_generation = new_generation
            depth -= 1

    def select_node_descendants(
        self,
        nodes: dict[str, DbtNode],
        root_id: str,
        selected_nodes: set[str],
        children_by_node: dict[str, set[str]] | None = None,
    ) -> None:
        """
        Parse original nodes and add the descendant nodes related to this config to the selected_nodes set.

        :param nodes: Original dbt nodes list
        :param root_id: Unique identifier of self.node_name
        :param selected_nodes: Set where descendant nodes will be added to.
        :param children_by_node: Optional precomputed parent id -> child ids index for ``nodes``;
            built from ``nodes`` when omitted.
        """
        if not self.descendants:
            return
        if children_by_node is None:
            children_by_node = self._index_children(nodes)
        depth = self.descendants_depth
        previous_generation = {root_id}
        processed_nodes: set[str] = set()
        while depth and previous_generation:
            new_generation: set[str] = set()
            for node_id in previous_generation - processed_nodes:
                processed_nodes.add(node_id)
                new_generation.update(children_by_node.get(node_id, ()))
            selected_nodes.update(new_generation)
            previous_generation = new_generation
            depth -= 1

    def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
        """
        Given a dictionary with the original dbt project nodes, applies the current graph selector to
        identify the subset of nodes that matches the selection criteria.

        :param nodes: dbt project nodes
        :return: set of node ids that matches current graph selector
        """
        root_nodes = self._find_root_nodes(nodes)
        selected_nodes: set[str] = set(root_nodes)
        self._select_nodes(nodes, root_nodes, selected_nodes)
        return selected_nodes

    def _find_root_nodes(self, nodes: dict[str, DbtNode]) -> set[str]:
        """Return the ids matched by ``node_name`` itself, before any graph operator is applied."""
        if self.node_name.startswith(PATH_SELECTOR):
            path_selection = self.node_name[len(PATH_SELECTOR) :]
            return {node_id for node_id, node in nodes.items() if path_selection in str(node.file_path)}
        if self.node_name.startswith(TAG_SELECTOR):
            tag_selection = self.node_name[len(TAG_SELECTOR) :]
            return {node_id for node_id, node in nodes.items() if tag_selection in node.tags}
        if self.node_name.startswith(CONFIG_SELECTOR):
            return self._find_nodes_by_config(nodes)
        # Plain node name; if several nodes share a name, the last one iterated wins.
        node_by_name = {node.name: node_id for node_id, node in nodes.items()}
        if self.node_name in node_by_name:
            return {node_by_name[self.node_name]}
        logger.warning("Selector %s not found.", self.node_name)
        return set()

    def _find_nodes_by_config(self, nodes: dict[str, DbtNode]) -> set[str]:
        """Return the ids of the nodes matching a ``config.<key>:<value>`` node name."""
        # Split on the first ":" only, so values that themselves contain ":" are kept whole.
        config_selection_key, config_selection_value = self.node_name[len(CONFIG_SELECTOR) :].split(":", 1)
        if config_selection_key not in SUPPORTED_CONFIG:
            logger.warning("Unsupported config key selector: %s", config_selection_key)
            return set()
        # The config 'tags' contains a list of tags, whereas 'materialized' and 'schema' contain strings.
        if config_selection_key == "tags":
            return {
                node_id
                for node_id, node in nodes.items()
                if config_selection_value in node.config.get(config_selection_key, [])
            }
        if config_selection_key in ("materialized", "schema"):
            return {
                node_id
                for node_id, node in nodes.items()
                if config_selection_value == node.config.get(config_selection_key, "")
            }
        return set()

    def _select_nodes(self, nodes: dict[str, DbtNode], root_nodes: set[str], selected_nodes: set[str]) -> None:
        """
        Handle selection of nodes based on the graph selector configuration.

        :param nodes: dbt project nodes
        :param root_nodes: Set of root node ids
        :param selected_nodes: Set where selected nodes will be added to.
        """
        # Build the parent -> children index once for all roots rather than once per root.
        children_by_node = self._index_children(nodes) if (self.descendants and root_nodes) else None
        if self.at_operator:
            descendants: set[str] = set()
            # First get all descendants
            for root_id in root_nodes:
                self.select_node_descendants(nodes, root_id, descendants, children_by_node)
            selected_nodes.update(descendants)
            # Get ancestors for root nodes and all descendants
            for node_id in root_nodes | descendants:
                self.select_node_precursors(nodes, node_id, selected_nodes)
        else:
            # Normal selection
            for root_id in root_nodes:
                self.select_node_precursors(nodes, root_id, selected_nodes)
                self.select_node_descendants(nodes, root_id, selected_nodes, children_by_node)
class SelectorConfig:
    """
    Represents a select/exclude statement.

    Supports to load it from a string.
    """

    def __init__(self, project_dir: Path | None, statement: str):
        """
        Create a selector config file.

        :param project_dir: Directory to a dbt project
        :param statement: dbt statement as passed within select and exclude arguments

        References:
        https://docs.getdbt.com/reference/node-selection/syntax
        https://docs.getdbt.com/reference/node-selection/yaml-selectors
        """
        self.project_dir = project_dir
        self.paths: list[Path] = []
        self.tags: list[str] = []
        self.config: dict[str, str] = {}
        self.other: list[str] = []
        self.graph_selectors: list[GraphSelector] = []
        self.load_from_statement(statement)

    @property
    def is_empty(self) -> bool:
        """True when the statement produced no selection criteria at all."""
        return not (self.paths or self.tags or self.config or self.graph_selectors or self.other)

    def load_from_statement(self, statement: str) -> None:
        """
        Load in-place select parameters.

        Each comma-separated item of ``statement`` is classified as a path, tag, config or graph
        selector. Non-empty items Cosmos cannot interpret are stored in ``self.other`` and a warning
        is logged.

        :param statement: dbt statement as passed within select and exclude arguments

        References:
        https://docs.getdbt.com/reference/node-selection/syntax
        https://docs.getdbt.com/reference/node-selection/yaml-selectors
        """
        for item in statement.split(","):
            regex_match = re.search(GRAPH_SELECTOR_REGEX, item)
            precursors, node_name, descendants = regex_match.groups() if regex_match else (None, None, None)
            if node_name is None:
                # GRAPH_SELECTOR_REGEX ends with an empty alternative, so items that are not valid
                # selectors (e.g. "a+b") still "match" with every group None. Report them instead of
                # dropping them silently; empty items (e.g. from a trailing comma) are skipped.
                self._record_unsupported(item)
            elif precursors or descendants:
                self._parse_unknown_selector(item)
            elif node_name.startswith(PATH_SELECTOR):
                self._parse_path_selector(item)
            elif "/" in node_name:
                self._parse_path_selector(f"{PATH_SELECTOR}{node_name}")
            elif node_name.startswith(TAG_SELECTOR):
                self._parse_tag_selector(item)
            elif node_name.startswith(CONFIG_SELECTOR):
                self._parse_config_selector(item)
            else:
                self._parse_unknown_selector(item)

    def _record_unsupported(self, item: str) -> None:
        """Store a non-empty item Cosmos cannot interpret in ``self.other`` and warn about it."""
        if item:
            self.other.append(item)
            logger.warning("Unsupported select statement: %s", item)

    def _parse_unknown_selector(self, item: str) -> None:
        """Parse ``item`` as a graph selector, falling back to ``self.other`` when it is not one."""
        if not item:
            return
        graph_selector = GraphSelector.parse(item)
        if graph_selector is not None:
            self.graph_selectors.append(graph_selector)
        else:
            self._record_unsupported(item)

    def _parse_config_selector(self, item: str) -> None:
        """Parse a ``config.<key>:<value>`` item; keys outside SUPPORTED_CONFIG are ignored with a warning."""
        # Split on the first ":" only, so values that themselves contain ":" are kept whole.
        key, value = item[len(CONFIG_SELECTOR) :].split(":", 1)
        if key in SUPPORTED_CONFIG:
            self.config[key] = value
        else:
            logger.warning("Unsupported config key selector: %s", key)

    def _parse_tag_selector(self, item: str) -> None:
        """Parse a ``tag:<name>`` item."""
        index = len(TAG_SELECTOR)
        self.tags.append(item[index:])

    def _parse_path_selector(self, item: str) -> None:
        """Parse a ``path:<path>`` item, resolving it against ``project_dir`` when one is set."""
        index = len(PATH_SELECTOR)
        if self.project_dir:
            self.paths.append(self.project_dir / Path(item[index:]))
        else:
            self.paths.append(Path(item[index:]))

    def __repr__(self) -> str:
        return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other}, graph_selectors={self.graph_selectors})"
class NodeSelector:
    """
    Class to select nodes based on a selector config.

    :param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
    :param config: User-defined select statements
    """

    def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None:
        self.nodes = nodes
        self.config = config
        self.selected_nodes: set[str] = set()
        # Ids already evaluated by ``_should_include_node``.
        self.visited_nodes: set[str] = set()
        # Memoized decision of ``_should_include_node`` for each evaluated id.
        self._inclusion_by_node: dict[str, bool] = {}

    @property
    def select_nodes_ids_by_intersection(self) -> set[str]:
        """
        Return the set of node ids which match the configuration defined in config.

        All nodes are returned when the config is empty. Otherwise a node must satisfy the tag,
        config and path criteria, and, when graph selectors are present, also be selected by all
        of them.

        References:
        https://docs.getdbt.com/reference/node-selection/syntax
        https://docs.getdbt.com/reference/node-selection/yaml-selectors
        """
        if self.config.is_empty:
            return set(self.nodes.keys())

        # Every evaluation starts from a clean cache.
        self.visited_nodes = set()
        self._inclusion_by_node = {}

        selected_nodes = {
            node_id for node_id, node in self.nodes.items() if self._should_include_node(node_id, node)
        }
        if self.config.graph_selectors:
            selected_nodes &= self.select_by_graph_operator()

        self.selected_nodes = selected_nodes
        return selected_nodes

    def _should_include_node(self, node_id: str, node: DbtNode) -> bool:
        """
        Check whether a single node should be included, evaluating each node at most once.

        Decisions are memoized per node id: a schema.yml test may trigger the evaluation of its
        parent model (via ``_is_path_matching``) before the main loop reaches that model, and both
        lookups must then get the same answer.
        """
        logger.debug("Inspecting if the node <%s> should be included.", node_id)
        if node_id in self._inclusion_by_node:
            return self._inclusion_by_node[node_id]
        self.visited_nodes.add(node_id)
        # Provisional answer while this node is being evaluated, so a re-entrant lookup of the same
        # id terminates with False instead of recursing.
        self._inclusion_by_node[node_id] = False
        include = self._evaluate_node(node_id, node)
        self._inclusion_by_node[node_id] = include
        return include

    def _evaluate_node(self, node_id: str, node: DbtNode) -> bool:
        """Apply the tag, config and path criteria of ``self.config`` to a single node."""
        if node.resource_type == DbtResourceType.TEST and node.depends_on:
            # Tests take the tags of their first parent.
            # NOTE(review): this overwrites ``tags`` on the shared DbtNode object, so the change is
            # visible to every later user of ``self.nodes``.
            node.tags = getattr(self.nodes.get(node.depends_on[0]), "tags", [])
            logger.debug(
                "The test node <%s> inherited these tags from the parent node <%s>: %s",
                node_id,
                node.depends_on[0],
                node.tags,
            )

        if not self._is_tags_subset(node):
            logger.debug("Excluding node <%s>", node_id)
            return False

        node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
        if not self._is_config_subset(node_config):
            return False

        # Remove 'tags' as they've already been filtered for; remaining keys must match exactly.
        config_copy = copy.deepcopy(self.config.config)
        config_copy.pop("tags", None)
        node_config.pop("tags", None)
        if not (config_copy.items() <= node_config.items()):
            return False

        if self.config.paths and not self._is_path_matching(node):
            return False

        return True

    def _is_tags_subset(self, node: DbtNode) -> bool:
        """Checks if every tag requested by the config is present in the node's tags."""
        return set(self.config.tags) <= set(node.tags)

    def _is_config_subset(self, node_config: dict[str, Any]) -> bool:
        """Checks that the ``config.tags`` value requested by the config (if any) is among the node's config tags."""
        config_tags = self.config.config.get("tags")
        if config_tags and config_tags not in node_config.get("tags", []):
            return False
        return True

    def _is_path_matching(self, node: DbtNode) -> bool:
        """Checks if the node's file is one of the config's paths or lies under one of them."""
        for filter_path in self.config.paths:
            if filter_path in node.file_path.parents or filter_path == node.file_path:
                return True

        # if it's a test coming from a schema.yml file, check the model's file_path
        if node.resource_type == DbtResourceType.TEST and node.file_path.name == "schema.yml":
            # try to get the corresponding model from node.depends_on
            if len(node.depends_on) == 1:
                model_node = self.nodes.get(node.depends_on[0])
                if model_node:
                    return self._should_include_node(node.depends_on[0], model_node)

        return False

    def select_by_graph_operator(self) -> set[str]:
        """
        Return the set of node ids selected by every graph selector of the config.

        Each graph selector is applied to all nodes and the per-selector results are intersected.
        Requires at least one graph selector in the config.

        References:
        https://docs.getdbt.com/reference/node-selection/syntax
        https://docs.getdbt.com/reference/node-selection/yaml-selectors
        """
        selected_nodes_by_selector = [
            graph_selector.filter_nodes(self.nodes) for graph_selector in self.config.graph_selectors
        ]
        return set.intersection(*selected_nodes_by_selector)
def retrieve_by_label(statement_list: list[str], label: str) -> set[str]:
    """
    Return a set of values associated with a label.

    Each statement is parsed into a ``SelectorConfig`` (with ``Path()`` as project directory) and the
    values of its attribute named ``label`` (for example ``"tags"``) are merged into one set.

    :param statement_list: select/exclude statements to inspect
    :param label: name of the ``SelectorConfig`` attribute to collect
    :return: union of the collected values

    Example:
    >>> values = retrieve_by_label(["path:/tmp,tag:a", "tag:b,path:/home"], "tags")
    >>> sorted(values)
    ['a', 'b']
    """
    label_values: set[str] = set()
    for statement in statement_list:
        config = SelectorConfig(Path(), statement)
        item_values = getattr(config, label)
        label_values.update(item_values)
    return label_values
def select_nodes(
    project_dir: Path | None,
    nodes: dict[str, DbtNode],
    select: list[str] | None = None,
    exclude: list[str] | None = None,
) -> dict[str, DbtNode]:
    """
    Filter a project's nodes with dbt-style ``select`` and ``exclude`` statements.

    The ``nodes`` mapping is returned untouched when neither list has statements. Otherwise the
    filters are validated, the select statements (if any) narrow ``nodes`` down to the union of
    their matches, and the union of the exclude matches is then removed from what remains.

    References:
    https://docs.getdbt.com/reference/node-selection/syntax
    https://docs.getdbt.com/reference/node-selection/yaml-selectors
    """
    select_statements = select or []
    exclude_statements = exclude or []
    if not (select_statements or exclude_statements):
        return nodes

    validate_filters(exclude_statements, select_statements)

    selected_ids = apply_select_filter(nodes, project_dir, select_statements)
    candidates = get_nodes_from_subset(nodes, selected_ids) if select_statements else nodes

    excluded_ids = apply_exclude_filter(candidates, project_dir, exclude_statements)
    remaining_ids = set(candidates.keys()) - excluded_ids
    return get_nodes_from_subset(candidates, remaining_ids)
def get_nodes_from_subset(nodes: dict[str, DbtNode], subset_ids: set[str]) -> dict[str, DbtNode]:
    """
    Return the entries of ``nodes`` whose ids are in ``subset_ids``.

    The result keeps the iteration order of ``nodes`` instead of the order of ``subset_ids``: set
    iteration order for string keys depends on hash randomization, so it may differ between
    interpreter runs. Ids in ``subset_ids`` that are not keys of ``nodes`` are ignored.

    :param nodes: dbt nodes indexed by unique id
    :param subset_ids: ids of the nodes to keep
    :return: a new dict containing only the kept nodes
    """
    return {node_id: node for node_id, node in nodes.items() if node_id in subset_ids}
def apply_exclude_filter(nodes: dict[str, DbtNode], project_dir: Path | None, exclude: list[str]) -> set[str]:
    """
    Return the union of the node ids matched by each exclude statement.

    Every statement gets its own ``SelectorConfig`` and is resolved against ``nodes`` through
    ``NodeSelector.select_nodes_ids_by_intersection``.
    """
    excluded_ids: set[str] = set()
    for exclude_statement in exclude:
        selector = NodeSelector(nodes, SelectorConfig(project_dir, exclude_statement))
        excluded_ids |= selector.select_nodes_ids_by_intersection
    return excluded_ids
def apply_select_filter(nodes: dict[str, DbtNode], project_dir: Path | None, select: list[str]) -> set[str]:
    """
    Return the union of the node ids matched by each select statement.

    Every statement gets its own ``SelectorConfig`` and is resolved against ``nodes`` through
    ``NodeSelector.select_nodes_ids_by_intersection``.
    """
    matches_per_statement = [
        NodeSelector(nodes, SelectorConfig(project_dir, select_statement)).select_nodes_ids_by_intersection
        for select_statement in select
    ]
    return set().union(*matches_per_statement)
def validate_filters(exclude: list[str], select: list[str]) -> None:
    """
    Validate select and exclude filters.

    A statement that uses a selector method (``method:value``) must use one Cosmos supports:
    ``path:``, ``tag:`` or ``config.<key>:`` with ``<key>`` in SUPPORTED_CONFIG. Statements that
    contain a ``+`` graph operator, or that have no ``:`` at all, are accepted as-is. A leading
    ``@`` graph operator is ignored for the check, so e.g. ``@tag:nightly`` is accepted.

    :param exclude: exclude statements
    :param select: select statements
    :raises CosmosValueError: if a statement uses an unsupported selector method
    """
    supported_config_prefixes = tuple(f"{CONFIG_SELECTOR}{config}:" for config in SUPPORTED_CONFIG)
    for filter_type, filter_parameters in (("select", select), ("exclude", exclude)):
        for filter_parameter in filter_parameters:
            # "@" wraps a regular selector (GraphSelector supports "@tag:...", "@path:...", ...);
            # validate what follows it.
            if filter_parameter.startswith(AT_SELECTOR):
                selector = filter_parameter[len(AT_SELECTOR) :]
            else:
                selector = filter_parameter
            if (
                selector.startswith((PATH_SELECTOR, TAG_SELECTOR))
                or PLUS_SELECTOR in selector
                or selector.startswith(supported_config_prefixes)
            ):
                continue
            if ":" in selector:
                raise CosmosValueError(f"Invalid {filter_type} filter: {filter_parameter}")