From 19a39521dd57709e380039d06efe98067b7b5f89 Mon Sep 17 00:00:00 2001 From: Wes Bonelli Date: Wed, 26 Jul 2023 14:12:02 -0400 Subject: [PATCH] feat(set_env): add set_env contextmanager utility --- modflow_devtools/misc.py | 33 ++++++++++++++++++++++++++++++ modflow_devtools/test/test_misc.py | 17 +++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/modflow_devtools/misc.py b/modflow_devtools/misc.py index f42468d..c326c0b 100644 --- a/modflow_devtools/misc.py +++ b/modflow_devtools/misc.py @@ -29,6 +29,39 @@ def set_dir(path: PathLike): print(f"Returned to previous directory: {origin}") +@contextmanager +def set_env(*remove, **update): + """ + Temporarily updates the ``os.environ`` dictionary in-place. + + Referenced from https://stackoverflow.com/a/34333710/6514033. + + The ``os.environ`` dictionary is updated in-place so that the modification + is sure to work in all situations. + + :param remove: Environment variables to remove. + :param update: Dictionary of environment variables and values to add/update. + """ + env = environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + class add_sys_path: """ Context manager for temporarily editing the system path diff --git a/modflow_devtools/test/test_misc.py b/modflow_devtools/test/test_misc.py index 9ff517e..5003940 100644 --- a/modflow_devtools/test/test_misc.py +++ b/modflow_devtools/test/test_misc.py @@ -11,6 +11,7 @@ get_packages, has_package, set_dir, + set_env, ) @@ -21,6 +22,22 @@ def test_set_dir(tmp_path): assert Path(os.getcwd()) != tmp_path +def test_set_env(tmp_path): + # test adding a variable + key = "TEST_ENV" + val = "test" + assert environ.get(key) is None + with set_env(**{key: val}): + assert environ.get(key) == val + with set_env(TEST_ENV=val): + assert environ.get(key) == val + + # test removing a variable + with set_env(**{key: val}): + with set_env(key): + assert environ.get(key) is None + + _repos_path = environ.get("REPOS_PATH") if _repos_path is None: _repos_path = Path(__file__).parent.parent.parent.parent