From 6595272b2e5c64310c1eb4e7d1edba0689555c55 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 16 Dec 2022 13:31:59 -0800 Subject: [PATCH] determine spec shape only at mock construction time --- Lib/test/test_unittest/testmock/testasync.py | 8 +++++++- Lib/unittest/mock.py | 16 ++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_unittest/testmock/testasync.py b/Lib/test/test_unittest/testmock/testasync.py index 990b247e5ee975..52a3b71be1ef8d 100644 --- a/Lib/test/test_unittest/testmock/testasync.py +++ b/Lib/test/test_unittest/testmock/testasync.py @@ -303,9 +303,15 @@ def test_spec_normal_methods_on_class_with_mock(self): def test_spec_async_attributes_instance(self): async_instance = AsyncClass() async_instance.async_func_attr = async_func + async_instance.later_async_func_attr = normal_func + + mock_async_instance = Mock(spec_set=async_instance) + + async_instance.later_async_func_attr = async_func - mock_async_instance = Mock(async_instance) self.assertIsInstance(mock_async_instance.async_func_attr, AsyncMock) + # only the shape of the spec at the time of mock construction matters + self.assertNotIsInstance(mock_async_instance.later_async_func_attr, AsyncMock) def test_spec_mock_type_kw(self): def inner_test(mock_type): diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index e37d31dda201c0..583ab74a825531 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -506,10 +506,9 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, _spec_class = None _spec_signature = None - _spec_obj = None + _spec_asyncs = [] if spec is not None and not _is_list(spec): - _spec_obj = spec if isinstance(spec, type): _spec_class = spec else: @@ -518,14 +517,20 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, _spec_as_instance, _eat_self) _spec_signature = res and res[1] - spec = dir(spec) + spec_list = dir(spec) + + for attr in spec_list: + if iscoroutinefunction(getattr(spec, attr, None)): + _spec_asyncs.append(attr) + + spec = spec_list __dict__ = self.__dict__ __dict__['_spec_class'] = _spec_class - __dict__['_spec_obj'] = _spec_obj __dict__['_spec_set'] = spec_set __dict__['_spec_signature'] = _spec_signature __dict__['_mock_methods'] = spec + __dict__['_spec_asyncs'] = _spec_asyncs def __get_return_value(self): ret = self._mock_return_value @@ -1015,8 +1020,7 @@ def _get_child_mock(self, /, **kw): For non-callable mocks the callable variant will be used (rather than any custom subclass).""" _new_name = kw.get("_new_name") - _spec_val = getattr(self.__dict__["_spec_obj"], _new_name, None) - if _spec_val is not None and asyncio.iscoroutinefunction(_spec_val): + if _new_name in self.__dict__['_spec_asyncs']: return AsyncMock(**kw) if self._mock_sealed: