Skip to content

Commit

Permalink
feat: proper test runner handler
Browse files Browse the repository at this point in the history
  • Loading branch information
cih9088 committed Jan 10, 2023
1 parent abdb636 commit 3031515
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 11 deletions.
2 changes: 2 additions & 0 deletions otx/algorithms/classification/adapters/mmcls/nncf/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def build_nncf_classifier( # pylint: disable=too-many-locals
else:
# pytorch ckpt
state_to_build_nncf = state_dict
if "state_dict" in state_dict:
state_to_build_nncf = state_dict["state_dict"]

# This data and state dict will be used to build NNCF graph later
# when loading NNCF model
Expand Down
4 changes: 0 additions & 4 deletions otx/algorithms/common/tasks/nncf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io
import json
import os
import tempfile
from collections.abc import Mapping
from copy import deepcopy
from typing import Dict, List, Optional
Expand Down Expand Up @@ -84,9 +83,6 @@ def __init__(self, task_environment: TaskEnvironment, **kwargs):
self._optimization_methods: List[OptimizationMethod] = []
self._precision = [ModelPrecision.FP32]

self._scratch_space = tempfile.mkdtemp(prefix="otx-nncf-scratch-")
logger.info(f"Scratch space created at {self._scratch_space}")

# Extra control variables.
self._training_work_dir = None
self._is_training = False
Expand Down
2 changes: 2 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/nncf/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def build_nncf_detector( # pylint: disable=too-many-locals,too-many-statements
else:
# pytorch ckpt
state_to_build_nncf = state_dict
if "state_dict" in state_dict:
state_to_build_nncf = state_dict["state_dict"]

# This data and state dict will be used to build NNCF graph later
# when loading NNCF model
Expand Down
2 changes: 2 additions & 0 deletions otx/algorithms/segmentation/adapters/mmseg/nncf/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def build_nncf_segmentor( # noqa: C901 # pylint: disable=too-many-locals
else:
# pytorch ckpt
state_to_build_nncf = state_dict
if "state_dict" in state_dict:
state_to_build_nncf = state_dict["state_dict"]

# This data and state dict will be used to build NNCF graph later
# when loading NNCF model
Expand Down
73 changes: 66 additions & 7 deletions otx/cli/utils/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import asyncio
import json
import os
import shutil
import subprocess # nosec
import sys

import pytest
Expand Down Expand Up @@ -49,12 +49,71 @@ def get_template_dir(template, root) -> str:
return template_work_dir


def check_run(cmd, **kwargs):
p = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True, bufsize=1, **kwargs)
for c in iter(lambda: p.stderr.read(1), ""):
sys.stderr.write(c)
p.communicate()
assert p.returncode == 0, "The process returned non zero."
def runner(
cmd,
stdout_stream=sys.stdout.buffer,
stderr_stream=sys.stderr.buffer,
**kwargs,
):
async def stream_handler(in_stream, out_stream):
output = []
while True:
line = await in_stream.readline()
if not line:
break
output.append(line)
out_stream.write(line) # assume it doesn't block
return b"".join(output)

async def run_and_capture(cmd):
environ = os.environ.copy()
environ["PYTHONUNBUFFERED"] = "1"
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=environ,
**kwargs,
)

try:
stdout, stderr = await asyncio.gather(
stream_handler(process.stdout, stdout_stream),
stream_handler(process.stderr, stderr_stream),
)
except Exception:
process.kill()
raise
finally:
rc = await process.wait()
return rc, stdout, stderr

if os.name == "nt":
# for subprocess' pipes on Windows
loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(loop)
else:
loop = asyncio.get_event_loop()
rc, stdout, stderr = loop.run_until_complete(run_and_capture(cmd))
loop.close()

return rc, stdout, stderr


def check_run(cmd):
rc, _, stderr = runner(cmd)

sys.stdout.flush()
sys.stderr.flush()

if rc != 0:
stderr = stderr.decode("utf-8").splitlines()
i = 0
for i, line in enumerate(stderr):
if line.startswith("Traceback"):
break
stderr = "\n".join(stderr[i:])
assert rc == 0, stderr


def otx_train_testing(template, root, otx_dir, args):
Expand Down

0 comments on commit 3031515

Please sign in to comment.