diff --git a/traits/etsconfig/etsconfig.py b/traits/etsconfig/etsconfig.py index 45b66cd02..b718c487f 100644 --- a/traits/etsconfig/etsconfig.py +++ b/traits/etsconfig/etsconfig.py @@ -12,10 +12,9 @@ # Standard library imports. -import sys +import contextlib import os -from os import path -from contextlib import contextmanager +import sys class ETSToolkitError(RuntimeError): @@ -94,7 +93,7 @@ def get_application_data(self, create=False): - The actual location differs between operating systems. - """ + """ if self._application_data is None: self._application_data = self._initialize_application_data( create=create @@ -155,9 +154,9 @@ def get_application_home(self, create=False): - The actual location differs between operating systems. - """ + """ if self._application_home is None: - self._application_home = path.join( + self._application_home = os.path.join( self.get_application_data(create=create), self._get_application_dirname(), ) @@ -208,7 +207,7 @@ def company(self, company): def company(self): self._company = None - @contextmanager + @contextlib.contextmanager def provisional_toolkit(self, toolkit): """ Perform an operation with toolkit provisionally set @@ -401,8 +400,8 @@ def _get_application_dirname(self): main_mod = sys.modules.get("__main__", None) if main_mod is not None: if hasattr(main_mod, "__file__"): - main_mod_file = path.abspath(main_mod.__file__) - dirname = path.basename(path.dirname(main_mod_file)) + main_mod_file = os.path.abspath(main_mod.__file__) + dirname = os.path.basename(os.path.dirname(main_mod_file)) return dirname diff --git a/traits/etsconfig/tests/test_etsconfig.py b/traits/etsconfig/tests/test_etsconfig.py index 5bc247e9c..cffbb807c 100644 --- a/traits/etsconfig/tests/test_etsconfig.py +++ b/traits/etsconfig/tests/test_etsconfig.py @@ -13,6 +13,7 @@ # Standard library imports. import contextlib import os +import pathlib import shutil import sys import tempfile @@ -69,7 +70,7 @@ def temporary_home_directory(): with temporary_directory() as temp_home: with restore_mapping_entry(os.environ, home_var): os.environ[home_var] = temp_home - yield + yield temp_home @contextlib.contextmanager @@ -109,11 +110,9 @@ def setUp(self): # Make a fresh instance each time. self.ETSConfig = type(ETSConfig)() - - def run(self, result=None): - # Extend TestCase.run to use a temporary home directory. - with temporary_home_directory(): - super().run(result) + with contextlib.ExitStack() as stack: + self._temp_home = stack.enter_context(temporary_home_directory()) + self.addCleanup(stack.pop_all().close) ########################################################################### # 'ETSConfigTestCase' interface. @@ -251,10 +250,10 @@ def test_default_application_home(self): (dirname, app_name) = os.path.split(app_home) self.assertEqual(dirname, self.ETSConfig.application_data) - - # The assumption here is that the test was run using unittest and not - # a different test runner e.g. using "python -m unittest ...". - self.assertEqual(app_name, "unittest") + self.assertEqual( + app_name, + pathlib.Path(sys.modules["__main__"].__file__).parts[-2] + ) def test_delete_application_home(self): # given