diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 2a3a7438aa..ee03fd9f2f 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -625,26 +625,38 @@ def remove_field(response: Response, *args, **kwargs) -> Response: def _extract_expressions( - template_string: Union[str, Dict], prefix: str = "" + template: Union[str, Dict], + prefix: str = "", ) -> List[str]: """Takes a template string and extracts expressions that start with a prefix. Args: - template_string (str): A string with expressions to extract + template (str): A string with expressions to extract prefix (str): A string that marks the beginning of an expression Example: >>> _extract_expressions("blog/{resources.blog.id}/comments", "resources.") ["resources.blog.id"] """ - # to use a dict with Formatter.parse we need to add curly brackets - if isinstance(template_string, dict): - template_string = "{" + json.dumps(template_string) + "}" - - return [ - field_name - for _, field_name, _, _ in string.Formatter().parse(template_string) - if field_name and field_name.startswith(prefix) - ] + expressions = set() + + def recursive_search(value): + if isinstance(value, dict): + for key, val in value.items(): + recursive_search(key) + recursive_search(val) + elif isinstance(value, list): + for item in value: + recursive_search(item) + elif isinstance(value, str): + e = [ + field_name + for _, field_name, _, _ in string.Formatter().parse(value) + if field_name and field_name.startswith(prefix) + ] + expressions.update(e) + + recursive_search(template) + return list(expressions) def _expressions_to_resolved_params(expressions: List[str]) -> List[ResolvedParam]: diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py index ba4bb6880f..ed09e59171 100644 --- a/tests/sources/rest_api/conftest.py +++ b/tests/sources/rest_api/conftest.py @@ -195,10 +195,16 @@ def posts_with_results_key(request, context): @router.post(r"/posts/search_by_id$") def search_posts_by_id(request, context): body = request.json() - print(body) post_id = body.get("post_id", 0) - print(post_id) - return {"id": int(post_id), "body": f"Post body {post_id}"} + title = body.get("more", {}).get("title", 0) + + more_array = body.get("more_array", [])[0] + return { + "id": int(post_id), + "title": title, + "body": f"Post body {post_id}", + "more": f"More is equale to id: {more_array}", + } @router.post(r"/posts/search$") def search_posts(request, context): diff --git a/tests/sources/rest_api/integration/test_offline.py b/tests/sources/rest_api/integration/test_offline.py index 77b6f6c5af..7c380fc5a2 100644 --- a/tests/sources/rest_api/integration/test_offline.py +++ b/tests/sources/rest_api/integration/test_offline.py @@ -108,7 +108,7 @@ def test_load_mock_api_with_query_params(mock_api_server): full_refresh=True, ) - mock_source = rest_api_source( + mock_source: RESTAPIConfig = rest_api_source( { "client": {"base_url": "https://api.example.com"}, "resources": [ @@ -244,14 +244,14 @@ def test_load_mock_api_with_json_resolved_with_implicit_param(mock_api_server): "method": "POST", "json": { "post_id": "{resources.posts.id}", + "limit": 5, + "more": { + "title": "{resources.posts.title}", + }, + "more_array": [ + "{resources.posts.id}", + ], }, - # "params": { - # "posts__id": { - # "type": "resolve", - # "resource": "posts", - # "field": "id", - # } - # }, }, }, ], @@ -287,6 +287,18 @@ def test_load_mock_api_with_json_resolved_with_implicit_param(mock_api_server): [f"Post body {i}" for i in range(25)], ) + assert_query_data( + pipeline, + f"SELECT title FROM {posts_details_table} ORDER BY id limit 25", + [f"Post {i}" for i in range(25)], + ) + + assert_query_data( + pipeline, + f"SELECT more FROM {posts_details_table} ORDER BY id limit 25", + [f"More is equale to id: {i}" for i in range(25)], + ) + def test_source_with_post_request(mock_api_server): class JSONBodyPageCursorPaginator(BaseReferencePaginator):