diff --git a/newrelic/api/database_trace.py b/newrelic/api/database_trace.py index 2bc4976887..1069be506f 100644 --- a/newrelic/api/database_trace.py +++ b/newrelic/api/database_trace.py @@ -16,7 +16,7 @@ import logging from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.database_node import DatabaseNode from newrelic.core.stack_trace import current_stack @@ -244,9 +244,9 @@ def create_node(self): ) -def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None): +def DatabaseTraceWrapper(wrapped, sql, dbapi2_module=None, async_wrapper=None): def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -273,9 +273,9 @@ def _nr_database_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_database_trace_wrapper_) -def database_trace(sql, dbapi2_module=None): - return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module) +def database_trace(sql, dbapi2_module=None, async_wrapper=None): + return functools.partial(DatabaseTraceWrapper, sql=sql, dbapi2_module=dbapi2_module, async_wrapper=async_wrapper) -def wrap_database_trace(module, object_path, sql, dbapi2_module=None): - wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module)) +def wrap_database_trace(module, object_path, sql, dbapi2_module=None, async_wrapper=None): + wrap_object(module, object_path, DatabaseTraceWrapper, (sql, dbapi2_module, async_wrapper)) diff --git a/newrelic/api/datastore_trace.py b/newrelic/api/datastore_trace.py index 1561293f64..0401c79ea5 100644 --- a/newrelic/api/datastore_trace.py +++ b/newrelic/api/datastore_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.datastore_node import DatastoreNode @@ -135,7 +135,7 @@ def create_node(self): ) -def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None): +def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Wraps a method to time datastore queries. :param wrapped: The function to apply the trace to. @@ -158,6 +158,8 @@ def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_p :param database_name: The name of database where the current query is being executed. :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None :rtype: :class:`newrelic.common.object_wrapper.FunctionWrapper` This is typically used to wrap datastore queries such as calls to Redis or @@ -173,7 +175,7 @@ def DatastoreTraceWrapper(wrapped, product, target, operation, host=None, port_p """ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -242,7 +244,7 @@ def _nr_datastore_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_datastore_trace_wrapper_) -def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None): +def datastore_trace(product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None): """Decorator allows datastore query to be timed. :param product: The name of the vendor. @@ -263,6 +265,8 @@ def datastore_trace(product, target, operation, host=None, port_path_or_id=None, :param database_name: The name of database where the current query is being executed. :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to decorate datastore queries such as calls to Redis or ElasticSearch. @@ -284,11 +288,12 @@ def datastore_trace(product, target, operation, host=None, port_path_or_id=None, host=host, port_path_or_id=port_path_or_id, database_name=database_name, + async_wrapper=async_wrapper, ) def wrap_datastore_trace( - module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None + module, object_path, product, target, operation, host=None, port_path_or_id=None, database_name=None, async_wrapper=None ): """Method applies custom timing to datastore query. @@ -314,6 +319,8 @@ def wrap_datastore_trace( :param database_name: The name of database where the current query is being executed. :type database_name: str + :param async_wrapper: An async trace wrapper from newrelic.common.async_wrapper. + :type async_wrapper: callable or None This is typically used to time database query method calls such as Redis GET. @@ -327,5 +334,5 @@ def wrap_datastore_trace( """ wrap_object( - module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name) + module, object_path, DatastoreTraceWrapper, (product, target, operation, host, port_path_or_id, database_name, async_wrapper) ) diff --git a/newrelic/api/external_trace.py b/newrelic/api/external_trace.py index c43c560c6c..2e147df450 100644 --- a/newrelic/api/external_trace.py +++ b/newrelic/api/external_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.external_node import ExternalNode @@ -66,9 +66,9 @@ def create_node(self): ) -def ExternalTraceWrapper(wrapped, library, url, method=None): +def ExternalTraceWrapper(wrapped, library, url, method=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -103,7 +103,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -125,9 +125,9 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def external_trace(library, url, method=None): - return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method) +def external_trace(library, url, method=None, async_wrapper=None): + return functools.partial(ExternalTraceWrapper, library=library, url=url, method=method, async_wrapper=async_wrapper) -def wrap_external_trace(module, object_path, library, url, method=None): - wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method)) +def wrap_external_trace(module, object_path, library, url, method=None, async_wrapper=None): + wrap_object(module, object_path, ExternalTraceWrapper, (library, url, method, async_wrapper)) diff --git a/newrelic/api/function_trace.py b/newrelic/api/function_trace.py index 474c1b2266..85d7617b68 100644 --- a/newrelic/api/function_trace.py +++ b/newrelic/api/function_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_names import callable_name from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.function_node import FunctionNode @@ -89,9 +89,9 @@ def create_node(self): ) -def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def FunctionTraceWrapper(wrapped, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): def dynamic_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -147,7 +147,7 @@ def dynamic_wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) def literal_wrapper(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -171,13 +171,13 @@ def literal_wrapper(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, literal_wrapper) -def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None): +def function_trace(name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None): return functools.partial( - FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup + FunctionTraceWrapper, name=name, group=group, label=label, params=params, terminal=terminal, rollup=rollup, async_wrapper=async_wrapper ) def wrap_function_trace( - module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None + module, object_path, name=None, group=None, label=None, params=None, terminal=False, rollup=None, async_wrapper=None ): - return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup)) + return wrap_object(module, object_path, FunctionTraceWrapper, (name, group, label, params, terminal, rollup, async_wrapper)) diff --git a/newrelic/api/graphql_trace.py b/newrelic/api/graphql_trace.py index 7a2c9ec02f..3d6ae6b09f 100644 --- a/newrelic/api/graphql_trace.py +++ b/newrelic/api/graphql_trace.py @@ -16,7 +16,7 @@ from newrelic.api.time_trace import TimeTrace, current_trace from newrelic.api.transaction import current_transaction -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.graphql_node import GraphQLOperationNode, GraphQLResolverNode @@ -109,9 +109,9 @@ def set_transaction_name(self, priority=None): transaction.set_transaction_name(name, "GraphQL", priority=priority) -def GraphQLOperationTraceWrapper(wrapped): +def GraphQLOperationTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -130,12 +130,12 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_operation_trace(): - return functools.partial(GraphQLOperationTraceWrapper) +def graphql_operation_trace(async_wrapper=None): + return functools.partial(GraphQLOperationTraceWrapper, async_wrapper=async_wrapper) -def wrap_graphql_operation_trace(module, object_path): - wrap_object(module, object_path, GraphQLOperationTraceWrapper) +def wrap_graphql_operation_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLOperationTraceWrapper, (async_wrapper,)) class GraphQLResolverTrace(TimeTrace): @@ -193,9 +193,9 @@ def create_node(self): ) -def GraphQLResolverTraceWrapper(wrapped): +def GraphQLResolverTraceWrapper(wrapped, async_wrapper=None): def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -214,9 +214,9 @@ def _nr_graphql_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_graphql_trace_wrapper_) -def graphql_resolver_trace(): - return functools.partial(GraphQLResolverTraceWrapper) +def graphql_resolver_trace(async_wrapper=None): + return functools.partial(GraphQLResolverTraceWrapper, async_wrapper=async_wrapper) -def wrap_graphql_resolver_trace(module, object_path): - wrap_object(module, object_path, GraphQLResolverTraceWrapper) +def wrap_graphql_resolver_trace(module, object_path, async_wrapper=None): + wrap_object(module, object_path, GraphQLResolverTraceWrapper, (async_wrapper,)) diff --git a/newrelic/api/memcache_trace.py b/newrelic/api/memcache_trace.py index 6657a9ce27..87f12f9fc7 100644 --- a/newrelic/api/memcache_trace.py +++ b/newrelic/api/memcache_trace.py @@ -15,7 +15,7 @@ import functools from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.memcache_node import MemcacheNode @@ -51,9 +51,9 @@ def create_node(self): ) -def MemcacheTraceWrapper(wrapped, command): +def MemcacheTraceWrapper(wrapped, command, async_wrapper=None): def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -80,9 +80,9 @@ def _nr_wrapper_memcache_trace_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_wrapper_memcache_trace_) -def memcache_trace(command): - return functools.partial(MemcacheTraceWrapper, command=command) +def memcache_trace(command, async_wrapper=None): + return functools.partial(MemcacheTraceWrapper, command=command, async_wrapper=async_wrapper) -def wrap_memcache_trace(module, object_path, command): - wrap_object(module, object_path, MemcacheTraceWrapper, (command,)) +def wrap_memcache_trace(module, object_path, command, async_wrapper=None): + wrap_object(module, object_path, MemcacheTraceWrapper, (command, async_wrapper)) diff --git a/newrelic/api/message_trace.py b/newrelic/api/message_trace.py index be819d7044..f564c41cb4 100644 --- a/newrelic/api/message_trace.py +++ b/newrelic/api/message_trace.py @@ -16,7 +16,7 @@ from newrelic.api.cat_header_mixin import CatHeaderMixin from newrelic.api.time_trace import TimeTrace, current_trace -from newrelic.common.async_wrapper import async_wrapper +from newrelic.common.async_wrapper import async_wrapper as get_async_wrapper from newrelic.common.object_wrapper import FunctionWrapper, wrap_object from newrelic.core.message_node import MessageNode @@ -91,9 +91,9 @@ def create_node(self): ) -def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True): +def MessageTraceWrapper(wrapped, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): - wrapper = async_wrapper(wrapped) + wrapper = async_wrapper if async_wrapper is not None else get_async_wrapper(wrapped) if not wrapper: parent = current_trace() if not parent: @@ -144,7 +144,7 @@ def _nr_message_trace_wrapper_(wrapped, instance, args, kwargs): return FunctionWrapper(wrapped, _nr_message_trace_wrapper_) -def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True): +def message_trace(library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): return functools.partial( MessageTraceWrapper, library=library, @@ -153,10 +153,11 @@ def message_trace(library, operation, destination_type, destination_name, params destination_name=destination_name, params=params, terminal=terminal, + async_wrapper=async_wrapper, ) -def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True): +def wrap_message_trace(module, object_path, library, operation, destination_type, destination_name, params={}, terminal=True, async_wrapper=None): wrap_object( - module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal) + module, object_path, MessageTraceWrapper, (library, operation, destination_type, destination_name, params, terminal, async_wrapper) ) diff --git a/tests/agent_features/_test_async_coroutine_trace.py b/tests/agent_features/_test_async_coroutine_trace.py index 51b81f5f64..1250b8c254 100644 --- a/tests/agent_features/_test_async_coroutine_trace.py +++ b/tests/agent_features/_test_async_coroutine_trace.py @@ -28,6 +28,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -41,6 +42,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_awaitable_timing(event_loop, trace, metric): @@ -79,6 +82,8 @@ def _test(): (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) @pytest.mark.parametrize("yield_from", [True, False]) diff --git a/tests/agent_features/_test_async_generator_trace.py b/tests/agent_features/_test_async_generator_trace.py index 8e7703a0db..f6097bc434 100644 --- a/tests/agent_features/_test_async_generator_trace.py +++ b/tests/agent_features/_test_async_generator_trace.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import gc import sys import time @@ -31,10 +30,10 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace -is_pypy = hasattr(sys, "pypy_version_info") asyncio = pytest.importorskip("asyncio") @@ -47,6 +46,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_async_generator_timing(event_loop, trace, metric): @@ -130,7 +131,7 @@ async def _test(): gen = agen() # kickstart the generator (the try/except logic is inside the # generator) - await anext(gen) + await gen.asend(None) await gen.athrow(ValueError) # consume the generator @@ -198,7 +199,7 @@ async def _test(): gen = agen() # kickstart the coroutine - await anext(gen) + await gen.asend(None) # trace should be ended/recorded by close await gen.aclose() @@ -279,7 +280,7 @@ async def _test(): gen = agen() # kickstart the coroutine - await anext(gen) + await gen.asend(None) assert await gen.asend("foobar") == "foobar" assert _received and _received[0] == "foobar" @@ -311,7 +312,7 @@ async def _test(): gen = agen() # kickstart the coroutine - await anext(gen) + await gen.asend(None) assert await gen.athrow(MyException) == "foobar" @@ -342,7 +343,7 @@ async def _test(): gen = agen() # kickstart the coroutine - await anext(gen) + await gen.asend(None) # async generator will raise StopAsyncIteration with pytest.raises(StopAsyncIteration): @@ -401,10 +402,6 @@ async def _test(): with pytest.raises(RuntimeError): await gen.aclose() - if is_pypy: - gen = None - gc.collect() - event_loop.run_until_complete(_test()) @@ -439,6 +436,26 @@ async def _test(): assert full_metrics[("Function/agen", "")].total_call_time < 0.1 +@validate_transaction_metrics( + "test_complete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], +) +@background_task(name="test_complete_async_generator") +def test_complete_async_generator(event_loop): + @function_trace(name="agen") + async def agen(): + for i in range(5): + yield i + + async def _test(): + gen = agen() + assert [x async for x in gen] == [x for x in range(5)] + + event_loop.run_until_complete(_test()) + + @pytest.mark.parametrize("nr_transaction", [True, False]) def test_incomplete_async_generator(event_loop, nr_transaction): @function_trace(name="agen") @@ -453,35 +470,44 @@ async def _test(): async for _ in c: break - if is_pypy: - # pypy is not guaranteed to delete the coroutine when it goes out - # of scope. This code "helps" pypy along. The test above is really - # just to verify that incomplete coroutines will "eventually" be - # cleaned up. In pypy, unfortunately that means it may not be - # reported all the time. A customer would be expected to call gc - # directly; however, they already have to handle this case since - # incomplete generators are well documented as having problems with - # pypy's gc. - - # See: - # http://doc.pypy.org/en/latest/cpython_differences.html#differences-related-to-garbage-collection-strategies - # https://bitbucket.org/pypy/pypy/issues/736 - del c - import gc - - gc.collect() - if nr_transaction: - _test = background_task(name="test_incomplete_coroutine")(_test) + _test = background_task(name="test_incomplete_async_generator")(_test) event_loop.run_until_complete(_test()) if nr_transaction: _test_incomplete_async_generator = validate_transaction_metrics( - "test_incomplete_coroutine", + "test_incomplete_async_generator", background_task=True, scoped_metrics=[("Function/agen", 1)], rollup_metrics=[("Function/agen", 1)], )(_test_incomplete_async_generator) _test_incomplete_async_generator() + + +def test_incomplete_async_generator_transaction_exited(event_loop): + @function_trace(name="agen") + async def agen(): + for _ in range(5): + yield + + @validate_transaction_metrics( + "test_incomplete_async_generator", + background_task=True, + scoped_metrics=[("Function/agen", 1)], + rollup_metrics=[("Function/agen", 1)], + ) + def _test_incomplete_async_generator(): + c = agen() + @background_task(name="test_incomplete_async_generator") + async def _test(): + async for _ in c: + break + + event_loop.run_until_complete(_test()) + + # Remove generator after transaction completes + del c + + _test_incomplete_async_generator() diff --git a/tests/agent_features/test_async_wrapper_detection.py b/tests/agent_features/test_async_wrapper_detection.py new file mode 100644 index 0000000000..bb1fd3f1e3 --- /dev/null +++ b/tests/agent_features/test_async_wrapper_detection.py @@ -0,0 +1,102 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import functools +import time + +from newrelic.api.background_task import background_task +from newrelic.api.database_trace import database_trace +from newrelic.api.datastore_trace import datastore_trace +from newrelic.api.external_trace import external_trace +from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace +from newrelic.api.memcache_trace import memcache_trace +from newrelic.api.message_trace import message_trace + +from newrelic.common.async_wrapper import generator_wrapper + +from testing_support.fixtures import capture_transaction_metrics +from testing_support.validators.validate_transaction_metrics import ( + validate_transaction_metrics, +) + +trace_metric_cases = [ + (functools.partial(function_trace, name="simple_gen"), "Function/simple_gen"), + (functools.partial(external_trace, library="lib", url="http://foo.com"), "External/foo.com/lib/"), + (functools.partial(database_trace, "select * from foo"), "Datastore/statement/None/foo/select"), + (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), + (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), + (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), +] + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_automatic_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace() + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + + for _ in gen(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 + + +@pytest.mark.parametrize("trace,metric", trace_metric_cases) +def test_manual_generator_trace_wrapper(trace, metric): + metrics = [] + full_metrics = {} + + @capture_transaction_metrics(metrics, full_metrics) + @validate_transaction_metrics( + "test_automatic_generator_trace_wrapper", background_task=True, scoped_metrics=[(metric, 1)], rollup_metrics=[(metric, 1)] + ) + @background_task(name="test_automatic_generator_trace_wrapper") + def _test(): + @trace(async_wrapper=generator_wrapper) + def wrapper_func(): + """Function that returns a generator object, obscuring the automatic introspection of async_wrapper()""" + def gen(): + time.sleep(0.1) + yield + time.sleep(0.1) + return gen() + + for _ in wrapper_func(): + pass + + _test() + + # Check that generators time the total call time (including pauses) + metric_key = (metric, "") + assert full_metrics[metric_key].total_call_time >= 0.2 diff --git a/tests/agent_features/test_coroutine_trace.py b/tests/agent_features/test_coroutine_trace.py index 36e365bc46..a30f1e70a8 100644 --- a/tests/agent_features/test_coroutine_trace.py +++ b/tests/agent_features/test_coroutine_trace.py @@ -31,6 +31,7 @@ from newrelic.api.datastore_trace import datastore_trace from newrelic.api.external_trace import external_trace from newrelic.api.function_trace import function_trace +from newrelic.api.graphql_trace import graphql_operation_trace, graphql_resolver_trace from newrelic.api.memcache_trace import memcache_trace from newrelic.api.message_trace import message_trace @@ -47,6 +48,8 @@ (functools.partial(datastore_trace, "lib", "foo", "bar"), "Datastore/statement/lib/foo/bar"), (functools.partial(message_trace, "lib", "op", "typ", "name"), "MessageBroker/lib/typ/op/Named/name"), (functools.partial(memcache_trace, "cmd"), "Memcache/cmd"), + (functools.partial(graphql_operation_trace), "GraphQL/operation/GraphQL///"), + (functools.partial(graphql_resolver_trace), "GraphQL/resolve/GraphQL/"), ], ) def test_coroutine_timing(trace, metric):