diff --git a/src/databricks/labs/ucx/source_code/linters/directfs.py b/src/databricks/labs/ucx/source_code/linters/directfs.py
index 35ae70ccfc..91611c2adb 100644
--- a/src/databricks/labs/ucx/source_code/linters/directfs.py
+++ b/src/databricks/labs/ucx/source_code/linters/directfs.py
@@ -3,7 +3,7 @@
 from abc import ABC
 from collections.abc import Iterable
 
-from astroid import Call, Const, InferenceError, JoinedStr, NodeNG  # type: ignore
+from astroid import Attribute, Call, Const, InferenceError, JoinedStr, Name, NodeNG  # type: ignore
 from sqlglot import Expression as SqlExpression, parse as parse_sql, ParseError as SqlParseError
 from sqlglot.expressions import Alter, Create, Delete, Drop, Identifier, Insert, Literal, Select
 
@@ -94,21 +94,24 @@ def visit_const(self, node: Const):
         if isinstance(node.value, str):
             self._check_str_constant(node, InferredValue([node]))
 
-    def _check_str_constant(self, source_node, inferred: InferredValue):
+    def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
         if self._already_reported(source_node, inferred):
             return
         # don't report on JoinedStr fragments
         if isinstance(source_node.parent, JoinedStr):
             return
-        # avoid duplicate advices that are reported by SparkSqlPyLinter
-        if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"):
-            return
         value = inferred.as_string()
         for pattern in DIRECT_FS_ACCESS_PATTERNS:
             if not pattern.matches(value):
                 continue
+            # avoid false positives with relative URLs
+            if self._is_http_call_parameter(source_node):
+                return
+            # avoid duplicate advices that are reported by SparkSqlPyLinter
+            if self._prevent_spark_duplicates and Tree(source_node).is_from_module("spark"):
+                return
             # since we're normally filtering out spark calls, we're dealing with dfsas we know little about
-            # notable we don't know is_read or is_write
+            # notably we don't know is_read or is_write
             dfsa = DirectFsAccess(
                 path=value,
                 is_read=True,
@@ -117,6 +120,33 @@ def _check_str_constant(self, source_node: NodeNG, inferred: InferredValue):
             self._directfs_nodes.append(DirectFsAccessNode(dfsa, source_node))
             self._reported_locations.add((source_node.lineno, source_node.col_offset))
 
+    @classmethod
+    def _is_http_call_parameter(cls, source_node: NodeNG):
+        if not isinstance(source_node.parent, Call):
+            return False
+        # for now we only cater for ws.api_client.do
+        return cls._is_ws_api_client_do_call(source_node)
+
+    @classmethod
+    def _is_ws_api_client_do_call(cls, source_node: NodeNG):
+        assert isinstance(source_node.parent, Call)
+        func = source_node.parent.func
+        if not isinstance(func, Attribute) or func.attrname != "do":
+            return False
+        expr = func.expr
+        if not isinstance(expr, Attribute) or expr.attrname != "api_client":
+            return False
+        expr = expr.expr
+        if not isinstance(expr, Name):
+            return False
+        for value in InferredValue.infer_from_node(expr):
+            if not value.is_inferred():
+                continue
+            for node in value.nodes:
+                return Tree(node).is_instance_of("WorkspaceClient")
+        # at this point it seems safer to assume that expr is a WorkspaceClient than the opposite
+        return True
+
     def _already_reported(self, source_node: NodeNG, inferred: InferredValue):
         all_nodes = [source_node] + inferred.nodes
         return any((node.lineno, node.col_offset) in self._reported_locations for node in all_nodes)
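Reviewer note, not part of the patch: a minimal sketch of the attribute chain that _is_ws_api_client_do_call walks, using astroid's public extract_node helper; the ws receiver and the URL are made-up values for illustration.

import astroid

# The linter sees the string argument's parent Call, then walks func backwards:
# Call.func -> Attribute("do") -> Attribute("api_client") -> Name("ws").
call = astroid.extract_node('ws.api_client.do("GET", "/api/2.0/clusters/list")')
func = call.func
assert func.attrname == "do"               # outermost attribute must be .do
client = func.expr
assert client.attrname == "api_client"     # its receiver must be .api_client
receiver = client.expr
assert isinstance(receiver, astroid.Name)  # finally, a plain name whose type gets inferred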
diff --git a/src/databricks/labs/ucx/source_code/python/python_ast.py b/src/databricks/labs/ucx/source_code/python/python_ast.py
index 6588383aca..5bf6781374 100644
--- a/src/databricks/labs/ucx/source_code/python/python_ast.py
+++ b/src/databricks/labs/ucx/source_code/python/python_ast.py
@@ -11,14 +11,17 @@
     AssignName,
     Attribute,
     Call,
+    ClassDef,
     Const,
     Expr,
     Import,
     ImportFrom,
+    Instance,
     Module,
     Name,
     NodeNG,
     parse,
+    Uninferable,
 )
 
 logger = logging.getLogger(__name__)
@@ -160,6 +163,16 @@ def append_nodes(self, nodes: list[NodeNG]) -> None:
             node.parent = self_module
             self_module.body.append(node)
 
+    def is_instance_of(self, class_name: str) -> bool:
+        for inferred in self.node.inferred():
+            if inferred is Uninferable:
+                continue
+            if not isinstance(inferred, (Const, Instance)):
+                return False
+            proxied = getattr(inferred, "_proxied", None)
+            return isinstance(proxied, ClassDef) and proxied.name == class_name
+        return False
+
     def is_from_module(self, module_name: str) -> bool:
         return self._is_from_module(module_name, set())
 
diff --git a/src/databricks/labs/ucx/source_code/python/python_infer.py b/src/databricks/labs/ucx/source_code/python/python_infer.py
index 2ed7929260..b1d79da641 100644
--- a/src/databricks/labs/ucx/source_code/python/python_infer.py
+++ b/src/databricks/labs/ucx/source_code/python/python_infer.py
@@ -11,6 +11,7 @@
     decorators,
     Dict,
     FormattedValue,
+    Instance,
     JoinedStr,
     Name,
     NodeNG,
@@ -68,20 +69,28 @@ def _infer_values(cls, node: NodeNG) -> Iterator[Iterable[NodeNG]]:
         elif isinstance(node, FormattedValue):
             yield from _LocalInferredValue.do_infer_values(node.value)
         else:
-            yield from cls._infer_internal(node)
+            yield from cls._safe_infer_internal(node)
 
     @classmethod
-    def _infer_internal(cls, node: NodeNG):
+    def _safe_infer_internal(cls, node: NodeNG):
         try:
-            for inferred in node.inferred():
-                # work around infinite recursion of empty lists
-                if inferred == node:
-                    continue
-                yield from _LocalInferredValue.do_infer_values(inferred)
+            yield from cls._unsafe_infer_internal(node)
         except InferenceError as e:
             logger.debug(f"When inferring {node}", exc_info=e)
             yield [Uninferable]
 
+    @classmethod
+    def _unsafe_infer_internal(cls, node: NodeNG):
+        all_inferred = node.inferred()
+        if len(all_inferred) == 0 and isinstance(node, Instance):
+            yield [node]
+            return
+        for inferred in all_inferred:
+            # work around infinite recursion of empty lists
+            if inferred == node:
+                continue
+            yield from _LocalInferredValue.do_infer_values(inferred)
+
     @classmethod
     def _infer_values_from_joined_string(cls, node: NodeNG) -> Iterator[Iterable[NodeNG]]:
         assert isinstance(node, JoinedStr)
diff --git a/tests/unit/source_code/linters/test_directfs.py b/tests/unit/source_code/linters/test_directfs.py
index abfa78a8e8..2e482d8bb0 100644
--- a/tests/unit/source_code/linters/test_directfs.py
+++ b/tests/unit/source_code/linters/test_directfs.py
@@ -56,6 +56,13 @@ def test_detects_dfsa_paths(code, expected):
         ('spark.read.parquet("dbfs://mnt/foo/bar")', 1),
         ('DBFS="dbfs:/mnt/foo/bar"; spark.read.parquet(DBFS)', 1),
         ('a=f"/Repos/{thing1}/sdk-{thing2}-{thing3}"', 0),
+        (
+            """from databricks.sdk import WorkspaceClient
+ws = WorkspaceClient()
+ws.api_client.do("DELETE", "/api/2.0/feature-store/feature-tables/delete", body={"name": table["name"]})
+""",
+            0,
+        ),
     ],
 )
 def test_directfs_linter(code, expected):
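For context, again not part of the patch: is_instance_of leans on astroid proxying inferred Const and Instance values through a ClassDef. A small sketch under that assumption, with an arbitrary literal standing in for real code.

import astroid
from astroid import ClassDef, Const

# Infer the trailing name "a"; astroid resolves it to the constant 123,
# whose _proxied attribute is the ClassDef of the builtin int type.
node = astroid.extract_node("a = 123\na")
inferred = next(node.infer())
assert isinstance(inferred, Const)
proxied = getattr(inferred, "_proxied", None)
assert isinstance(proxied, ClassDef) and proxied.name == "int"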
diff --git a/tests/unit/source_code/python/test_python_ast.py b/tests/unit/source_code/python/test_python_ast.py
index c80abb5ceb..7d9f27e73e 100644
--- a/tests/unit/source_code/python/test_python_ast.py
+++ b/tests/unit/source_code/python/test_python_ast.py
@@ -1,5 +1,5 @@
 import pytest
-from astroid import Assign, AstroidSyntaxError, Attribute, Call, Const, Expr, Name  # type: ignore
+from astroid import Assign, AstroidSyntaxError, Attribute, Call, Const, Expr, Module, Name  # type: ignore
 
 from databricks.labs.ucx.source_code.python.python_ast import Tree, TreeHelper
 from databricks.labs.ucx.source_code.python.python_infer import InferredValue
@@ -161,6 +161,16 @@ def test_is_from_module():
     assert Tree(save_call).is_from_module("spark")
 
 
+@pytest.mark.parametrize("source, name, class_name", [("a = 123", "a", "int")])
+def test_is_instance_of(source, name, class_name):
+    tree = Tree.normalize_and_parse(source)
+    assert isinstance(tree.node, Module)
+    module = tree.node
+    var = module.globals.get(name, None)
+    assert isinstance(var, list) and len(var) > 0
+    assert Tree(var[0]).is_instance_of(class_name)
+
+
 def test_supports_recursive_refs_when_checking_module():
     source_1 = """
 df = spark.read.csv("hi")
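For reference, a compact usage sketch of the new helper outside the test harness; it assumes only the APIs added in this patch, and the assignment is an arbitrary example.

from databricks.labs.ucx.source_code.python.python_ast import Tree

# Parse a trivial assignment and check the inferred type of its target.
tree = Tree.normalize_and_parse("a = 123")
assign = tree.node.body[0]   # the Assign statement
target = assign.targets[0]   # the AssignName for "a"
assert Tree(target).is_instance_of("int")
assert not Tree(target).is_instance_of("str")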