Skip to content

Commit

Permalink
adding test script for github actions (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
askhade authored Sep 22, 2020
1 parent 5af49eb commit 2fd86ac
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions workflow_scripts/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import onnx
from pathlib import Path
import subprocess
import sys

def run_lfs_install():
result = subprocess.run(['git', 'lfs', 'install'], cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print("Git LFS install completed with return code=" + str(result.returncode))

def pull_lfs_file(file_name):
result = subprocess.run(['git', 'lfs', 'pull', '--include', file_name, '--exclude', '\"\"'], cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print("LFS pull completed with return code=" + str(result.returncode))

cwd_path = Path.cwd()

# obtain list of added or modified files in this PR
obtain_diff = subprocess.Popen(['git', 'diff', '--name-only', '--diff-filter=AM', 'origin/master', 'HEAD'],
cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdoutput, stderroutput = obtain_diff.communicate()
diff_list = stdoutput.split()

# identify list of changed onnx models in model Zoo
model_list = [str(model).replace("b'","").replace("'", "") for model in diff_list if ".onnx" in str(model)]

# run lfs install before starting the tests
run_lfs_install()

print("\n=== Running ONNX Checker on added models ===\n")
# run checker on each model
failed_models = []
for model_path in model_list:
model_name = model_path.split('/')[-1]
print("Testing:", model_name)

try:
pull_lfs_file(model_path)
model = onnx.load(model_path)
onnx.checker.check_model(model)
print("Model", model_name, "has been successfully checked!")
except Exception as e:
print(e)
failed_models.append(model_path)

if len(failed_models) != 0:
print(str(len(failed_models)) +" models failed onnx checker.")
sys.exit(1)

print(len(model_list), "model(s) checked.")

0 comments on commit 2fd86ac

Please sign in to comment.