From 3031515ac2687b443951604347c6e237985df729 Mon Sep 17 00:00:00 2001 From: Inhyuk Andy Cho Date: Tue, 10 Jan 2023 11:41:33 +0900 Subject: [PATCH] feat: proper test runner handler --- .../adapters/mmcls/nncf/builder.py | 2 + otx/algorithms/common/tasks/nncf_base.py | 4 - .../detection/adapters/mmdet/nncf/builder.py | 2 + .../adapters/mmseg/nncf/builder.py | 2 + otx/cli/utils/tests.py | 73 +++++++++++++++++-- 5 files changed, 72 insertions(+), 11 deletions(-) diff --git a/otx/algorithms/classification/adapters/mmcls/nncf/builder.py b/otx/algorithms/classification/adapters/mmcls/nncf/builder.py index 52cfa7beb44..aa4508d6120 100644 --- a/otx/algorithms/classification/adapters/mmcls/nncf/builder.py +++ b/otx/algorithms/classification/adapters/mmcls/nncf/builder.py @@ -72,6 +72,8 @@ def build_nncf_classifier( # pylint: disable=too-many-locals else: # pytorch ckpt state_to_build_nncf = state_dict + if "state_dict" in state_dict: + state_to_build_nncf = state_dict["state_dict"] # This data and state dict will be used to build NNCF graph later # when loading NNCF model diff --git a/otx/algorithms/common/tasks/nncf_base.py b/otx/algorithms/common/tasks/nncf_base.py index 9fdaab3c2ee..68d32e2d622 100644 --- a/otx/algorithms/common/tasks/nncf_base.py +++ b/otx/algorithms/common/tasks/nncf_base.py @@ -18,7 +18,6 @@ import io import json import os -import tempfile from collections.abc import Mapping from copy import deepcopy from typing import Dict, List, Optional @@ -84,9 +83,6 @@ def __init__(self, task_environment: TaskEnvironment, **kwargs): self._optimization_methods: List[OptimizationMethod] = [] self._precision = [ModelPrecision.FP32] - self._scratch_space = tempfile.mkdtemp(prefix="otx-nncf-scratch-") - logger.info(f"Scratch space created at {self._scratch_space}") - # Extra control variables. self._training_work_dir = None self._is_training = False diff --git a/otx/algorithms/detection/adapters/mmdet/nncf/builder.py b/otx/algorithms/detection/adapters/mmdet/nncf/builder.py index 204f5e5cc3b..59a68eef17a 100644 --- a/otx/algorithms/detection/adapters/mmdet/nncf/builder.py +++ b/otx/algorithms/detection/adapters/mmdet/nncf/builder.py @@ -77,6 +77,8 @@ def build_nncf_detector( # pylint: disable=too-many-locals,too-many-statements else: # pytorch ckpt state_to_build_nncf = state_dict + if "state_dict" in state_dict: + state_to_build_nncf = state_dict["state_dict"] # This data and state dict will be used to build NNCF graph later # when loading NNCF model diff --git a/otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py b/otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py index 0618ee9664c..8d77f88170b 100644 --- a/otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py +++ b/otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py @@ -77,6 +77,8 @@ def build_nncf_segmentor( # noqa: C901 # pylint: disable=too-many-locals else: # pytorch ckpt state_to_build_nncf = state_dict + if "state_dict" in state_dict: + state_to_build_nncf = state_dict["state_dict"] # This data and state dict will be used to build NNCF graph later # when loading NNCF model diff --git a/otx/cli/utils/tests.py b/otx/cli/utils/tests.py index 95ddf8ae752..9ef632f96b3 100644 --- a/otx/cli/utils/tests.py +++ b/otx/cli/utils/tests.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions # and limitations under the License. +import asyncio import json import os import shutil -import subprocess # nosec import sys import pytest @@ -49,12 +49,71 @@ def get_template_dir(template, root) -> str: return template_work_dir -def check_run(cmd, **kwargs): - p = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True, bufsize=1, **kwargs) - for c in iter(lambda: p.stderr.read(1), ""): - sys.stderr.write(c) - p.communicate() - assert p.returncode == 0, "The process returned non zero." +def runner( + cmd, + stdout_stream=sys.stdout.buffer, + stderr_stream=sys.stderr.buffer, + **kwargs, +): + async def stream_handler(in_stream, out_stream): + output = [] + while True: + line = await in_stream.readline() + if not line: + break + output.append(line) + out_stream.write(line) # assume it doesn't block + return b"".join(output) + + async def run_and_capture(cmd): + environ = os.environ.copy() + environ["PYTHONUNBUFFERED"] = "1" + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=environ, + **kwargs, + ) + + try: + stdout, stderr = await asyncio.gather( + stream_handler(process.stdout, stdout_stream), + stream_handler(process.stderr, stderr_stream), + ) + except Exception: + process.kill() + raise + finally: + rc = await process.wait() + return rc, stdout, stderr + + if os.name == "nt": + # for subprocess' pipes on Windows + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + rc, stdout, stderr = loop.run_until_complete(run_and_capture(cmd)) + loop.close() + + return rc, stdout, stderr + + +def check_run(cmd): + rc, _, stderr = runner(cmd) + + sys.stdout.flush() + sys.stderr.flush() + + if rc != 0: + stderr = stderr.decode("utf-8").splitlines() + i = 0 + for i, line in enumerate(stderr): + if line.startswith("Traceback"): + break + stderr = "\n".join(stderr[i:]) + assert rc == 0, stderr def otx_train_testing(template, root, otx_dir, args):