diff --git a/pangeo_forge_runner/recipe_rewriter.py b/pangeo_forge_runner/recipe_rewriter.py index aa3b3817..5f6eb924 100644 --- a/pangeo_forge_runner/recipe_rewriter.py +++ b/pangeo_forge_runner/recipe_rewriter.py @@ -114,20 +114,46 @@ def _make_injected_get( keywords=[], ) + def inject_keywords(self, node: Call) -> Call: + """Inject keywords into calls.""" + for name, params in self.callable_args_injections.items(): + if hasattr(node.func, "id") and name == node.func.id: + # this is a non-chained call, so append to top-level `.keywords` + node.keywords += [ + keyword( + arg=k, + value=self._make_injected_get( + "_CALLABLE_ARGS_INJECTIONS", name, k + ), + ) + for k in params + ] + + elif hasattr(node.func, "value") and name == node.func.value.func.id: + # this is a *chained* call, so append to `.func.value.keywords` + node.func.value.keywords += [ + keyword( + arg=k, + value=self._make_injected_get( + "_CALLABLE_ARGS_INJECTIONS", name, k + ), + ) + for k in params + ] + return node + def visit_Call(self, node: Call) -> Call: """ Rewrite calls that return a FilePattern if we need to prune them """ if isinstance(node.func, Attribute): - # FIXME: Support it being imported as from apache_beam import Create too - if "apache_beam" not in self._import_aliases.values(): + if ( + # FIXME: Support it being imported as from apache_beam import Create too # if beam hasn't been imported, don't rewrite anything - return node - - # Only rewrite parameters to apache_beam.Create, regardless - # of how it is imported as - if node.func.attr == "Create" and ( - self._import_aliases.get(node.func.value.id) == "apache_beam" + "apache_beam" in self._import_aliases.values() + # Rewrite parameters to apache_beam.Create, regardless of how it is imported + and node.func.attr == "Create" + and self._import_aliases.get(node.func.value.id) == "apache_beam" ): # If there is a single argument pased to beam.Create, and it is .items() # This is the heurestic we use for figuring out that we are in fact operating on a FilePattern object @@ -137,19 +163,11 @@ def visit_Call(self, node: Call) -> Call: and node.args[0].func.attr == "items" ): return fix_missing_locations(self.transform_prune(node)) + elif node.func.attr == "with_resource_hints": + return fix_missing_locations(self.inject_keywords(node)) + elif isinstance(node.func, Name): # FIXME: Support importing in other ways - for name, params in self.callable_args_injections.items(): - if name == node.func.id: - node.keywords += [ - keyword( - arg=k, - value=self._make_injected_get( - "_CALLABLE_ARGS_INJECTIONS", name, k - ), - ) - for k in params - ] - return fix_missing_locations(node) + return fix_missing_locations(self.inject_keywords(node)) return node diff --git a/tests/rewriter-tests/callable-args-injection-chained/original.py b/tests/rewriter-tests/callable-args-injection-chained/original.py new file mode 100644 index 00000000..d6902534 --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/original.py @@ -0,0 +1,5 @@ +def some_callable(some_argument): + pass + + +some_callable().with_resource_hints() diff --git a/tests/rewriter-tests/callable-args-injection-chained/params.py b/tests/rewriter-tests/callable-args-injection-chained/params.py new file mode 100644 index 00000000..8b50eacc --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/params.py @@ -0,0 +1,4 @@ +# Parameters to be passed to RecipeRewriter constructor +params = dict( + prune=False, callable_args_injections={"some_callable": {"some_argument": 42}} +) diff --git a/tests/rewriter-tests/callable-args-injection-chained/rewritten.py b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py new file mode 100644 index 00000000..95db9ad6 --- /dev/null +++ b/tests/rewriter-tests/callable-args-injection-chained/rewritten.py @@ -0,0 +1,9 @@ +def some_callable(some_argument): + pass + + +some_callable( + some_argument=_CALLABLE_ARGS_INJECTIONS.get("some_callable", {}).get( # noqa + "some_argument" + ) +).with_resource_hints()