diff --git a/changelog.d/11607.misc b/changelog.d/11607.misc new file mode 100644 index 000000000000..e82f46776364 --- /dev/null +++ b/changelog.d/11607.misc @@ -0,0 +1 @@ +Improve opentracing support for requests which use a `ResponseCache`. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ead2198e14fe..b9c1cbffa5c5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -172,7 +172,7 @@ async def upgrade_room( user_id = requester.user.to_string() # Check if this room is already being upgraded by another person - for key in self._upgrade_response_cache.pending_result_cache: + for key in self._upgrade_response_cache.keys(): if key[0] == old_room_id and key[1] != user_id: # Two different people are trying to upgrade the same room. # Send the second an error. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 20ce294209ad..1548c5fd5786 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import collections import inspect import itertools @@ -55,7 +56,26 @@ _T = TypeVar("_T") -class ObservableDeferred(Generic[_T]): +class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta): + """Abstract base class defining the consumer interface of ObservableDeferred""" + + __slots__ = () + + @abc.abstractmethod + def observe(self) -> "defer.Deferred[_T]": + """Add a new observer for this ObservableDeferred + + This returns a brand new deferred that is resolved when the underlying + deferred is resolved. Interacting with the returned deferred does not + effect the underlying deferred. + + Note that the returned Deferred doesn't follow the Synapse logcontext rules - + you will probably want to `make_deferred_yieldable` it. + """ + ... + + +class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): """Wraps a deferred object so that we can add observer deferreds. These observer deferreds do not affect the callback chain of the original deferred. diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 88ccf443377c..a3eb5f741bfc 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,19 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Generic, + Iterable, + Optional, + TypeVar, +) import attr from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.opentracing import ( + active_span, + start_active_span, + start_active_span_follows_from, +) from synapse.util import Clock -from synapse.util.async_helpers import ObservableDeferred +from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred from synapse.util.caches import register_cache logger = logging.getLogger(__name__) +if TYPE_CHECKING: + import opentracing + # the type of the key in the cache KV = TypeVar("KV") @@ -54,6 +72,20 @@ class ResponseCacheContext(Generic[KV]): """ +@attr.s(auto_attribs=True) +class ResponseCacheEntry: + result: AbstractObservableDeferred + """The (possibly incomplete) result of the operation. + + Note that we continue to store an ObservableDeferred even after the operation + completes (rather than switching to an immediate value), since that makes it + easier to cache Failure results. + """ + + opentracing_span_context: "Optional[opentracing.SpanContext]" + """The opentracing span which generated/is generating the result""" + + class ResponseCache(Generic[KV]): """ This caches a deferred response. Until the deferred completes it will be @@ -63,10 +95,7 @@ class ResponseCache(Generic[KV]): """ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): - # This is poorly-named: it includes both complete and incomplete results. - # We keep complete results rather than switching to absolute values because - # that makes it easier to cache Failure results. - self.pending_result_cache: Dict[KV, ObservableDeferred] = {} + self._result_cache: Dict[KV, ResponseCacheEntry] = {} self.clock = clock self.timeout_sec = timeout_ms / 1000.0 @@ -75,56 +104,63 @@ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self) -> int: - return len(self.pending_result_cache) + return len(self._result_cache) def __len__(self) -> int: return self.size() - def get(self, key: KV) -> Optional[defer.Deferred]: - """Look up the given key. + def keys(self) -> Iterable[KV]: + """Get the keys currently in the result cache - Returns a new Deferred (which also doesn't follow the synapse - logcontext rules). You will probably want to make_deferred_yieldable the result. + Returns both incomplete entries, and (if the timeout on this cache is non-zero), + complete entries which are still in the cache. - If there is no entry for the key, returns None. + Note that the returned iterator is not safe in the face of concurrent execution: + behaviour is undefined if `wrap` is called during iteration. + """ + return self._result_cache.keys() + + def _get(self, key: KV) -> Optional[ResponseCacheEntry]: + """Look up the given key. Args: - key: key to get/set in the cache + key: key to get in the cache Returns: - None if there is no entry for this key; otherwise a deferred which - resolves to the result. + The entry for this key, if any; else None. """ - result = self.pending_result_cache.get(key) - if result is not None: + entry = self._result_cache.get(key) + if entry is not None: self._metrics.inc_hits() - return result.observe() + return entry else: self._metrics.inc_misses() return None def _set( - self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]" - ) -> "defer.Deferred[RV]": + self, + context: ResponseCacheContext[KV], + deferred: "defer.Deferred[RV]", + opentracing_span_context: "Optional[opentracing.SpanContext]", + ) -> ResponseCacheEntry: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, you should wrap normal synapse deferreds with synapse.logging.context.run_in_background). - Returns a new Deferred (which also doesn't follow the synapse logcontext rules). - You will probably want to make_deferred_yieldable the result. - Args: context: Information about the cache miss deferred: The deferred which resolves to the result. + opentracing_span_context: An opentracing span wrapping the calculation Returns: - A new deferred which resolves to the actual result. + The cache entry object. """ result = ObservableDeferred(deferred, consumeErrors=True) key = context.cache_key - self.pending_result_cache[key] = result + entry = ResponseCacheEntry(result, opentracing_span_context) + self._result_cache[key] = entry def on_complete(r: RV) -> RV: # if this cache has a non-zero timeout, and the callback has not cleared @@ -132,18 +168,18 @@ def on_complete(r: RV) -> RV: # its removal later. if self.timeout_sec and context.should_cache: self.clock.call_later( - self.timeout_sec, self.pending_result_cache.pop, key, None + self.timeout_sec, self._result_cache.pop, key, None ) else: # otherwise, remove the result immediately. - self.pending_result_cache.pop(key, None) + self._result_cache.pop(key, None) return r - # make sure we do this *after* adding the entry to pending_result_cache, + # make sure we do this *after* adding the entry to result_cache, # in case the result is already complete (in which case flipping the order would # leave us with a stuck entry in the cache). result.addBoth(on_complete) - return result.observe() + return entry async def wrap( self, @@ -189,20 +225,41 @@ async def handle_request(request): Returns: The result of the callback (from the cache, or otherwise) """ - result = self.get(key) - if not result: + entry = self._get(key) + if not entry: logger.debug( "[%s]: no cached result for [%s], calculating new one", self._name, key ) context = ResponseCacheContext(cache_key=key) if cache_context: kwargs["cache_context"] = context - d = run_in_background(callback, *args, **kwargs) - result = self._set(context, d) - elif not isinstance(result, defer.Deferred) or result.called: + + span_context: Optional[opentracing.SpanContext] = None + + async def cb() -> RV: + # NB it is important that we do not `await` before setting span_context! + nonlocal span_context + with start_active_span(f"ResponseCache[{self._name}].calculate"): + span = active_span() + if span: + span_context = span.context + return await callback(*args, **kwargs) + + d = run_in_background(cb) + entry = self._set(context, d, span_context) + return await make_deferred_yieldable(entry.result.observe()) + + result = entry.result.observe() + if result.called: logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: logger.info( "[%s]: using incomplete cached result for [%s]", self._name, key ) - return await make_deferred_yieldable(result) + + span_context = entry.opentracing_span_context + with start_active_span_follows_from( + f"ResponseCache[{self._name}].wait", + contexts=(span_context,) if span_context else (), + ): + return await make_deferred_yieldable(result) diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index 1e83ef2f33d5..025b73e32f90 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from unittest.mock import Mock + from parameterized import parameterized from twisted.internet import defer @@ -60,10 +63,15 @@ def test_cache_hit(self): self.successResultOf(wrap_d), "initial wrap result should be the same", ) + + # a second call should return the result without a call to the wrapped function + unexpected = Mock(spec=()) + wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected)) + unexpected.assert_not_called() self.assertEqual( expected_result, - self.successResultOf(cache.get(0)), - "cache should have the result", + self.successResultOf(wrap2_d), + "cache should still have the result", ) def test_cache_miss(self): @@ -80,7 +88,7 @@ def test_cache_miss(self): self.successResultOf(wrap_d), "initial wrap result should be the same", ) - self.assertIsNone(cache.get(0), "cache should not have the result now") + self.assertCountEqual([], cache.keys(), "cache should not have the result now") def test_cache_expire(self): cache = self.with_cache("short_cache", ms=1000) @@ -92,16 +100,20 @@ def test_cache_expire(self): ) self.assertEqual(expected_result, self.successResultOf(wrap_d)) + + # a second call should return the result without a call to the wrapped function + unexpected = Mock(spec=()) + wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected)) + unexpected.assert_not_called() self.assertEqual( expected_result, - self.successResultOf(cache.get(0)), + self.successResultOf(wrap2_d), "cache should still have the result", ) # cache eviction timer is handled self.reactor.pump((2,)) - - self.assertIsNone(cache.get(0), "cache should not have the result now") + self.assertCountEqual([], cache.keys(), "cache should not have the result now") def test_cache_wait_hit(self): cache = self.with_cache("neutral_cache") @@ -133,16 +145,21 @@ def test_cache_wait_expire(self): self.reactor.pump((1, 1)) self.assertEqual(expected_result, self.successResultOf(wrap_d)) + + # a second call should immediately return the result without a call to the + # wrapped function + unexpected = Mock(spec=()) + wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected)) + unexpected.assert_not_called() self.assertEqual( expected_result, - self.successResultOf(cache.get(0)), + self.successResultOf(wrap2_d), "cache should still have the result", ) # (1 + 1 + 2) > 3.0, cache eviction timer is handled self.reactor.pump((2,)) - - self.assertIsNone(cache.get(0), "cache should not have the result now") + self.assertCountEqual([], cache.keys(), "cache should not have the result now") @parameterized.expand([(True,), (False,)]) def test_cache_context_nocache(self, should_cache: bool): @@ -183,10 +200,16 @@ async def non_caching(o: str, cache_context: ResponseCacheContext[int]): self.assertEqual(expected_result, self.successResultOf(wrap2_d)) if should_cache: + unexpected = Mock(spec=()) + wrap3_d = defer.ensureDeferred(cache.wrap(0, unexpected)) + unexpected.assert_not_called() self.assertEqual( expected_result, - self.successResultOf(cache.get(0)), + self.successResultOf(wrap3_d), "cache should still have the result", ) + else: - self.assertIsNone(cache.get(0), "cache should not have the result") + self.assertCountEqual( + [], cache.keys(), "cache should not have the result now" + )