Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use wget to download models in the onnx model zoo #1433

Merged
merged 11 commits into from
May 19, 2022
156 changes: 87 additions & 69 deletions test/onnx-model-zoo/CheckONNXModelZoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import os
import sys
import signal
import argparse
import subprocess
import tempfile
Expand All @@ -30,7 +31,7 @@
$ cd models
$ ln -s /onnx_mlir/test/onnx-model/test/onnx-model-zoo/CheckONNXModelZoo.py CheckONNXModelZoo.py
$ ln -s /onnx_mlir/utils/RunONNXModel.py RunONNXModel.py
$ VERBOSE=1 ONNX_MLIR_HOME=/onnx-mlir/build/Release/ python CheckONNXModelZoo.py -pull-models -m mnist-8 -compile_args="-O3 -mcpu=z14"
$ VERBOSE=1 ONNX_MLIR_HOME=/onnx-mlir/build/Release/ python CheckONNXModelZoo.py -m mnist-8 -compile_args="-O3 -mcpu=z14"
"""

if (not os.environ.get('ONNX_MLIR_HOME', None)):
Expand All @@ -47,6 +48,8 @@
"""
VERBOSE = int(os.environ.get('VERBOSE', 0))

ONNX_MODEL_ZOO_URL = "https://github.com/onnx/models/raw/main"


def log_l1(*args):
if (VERBOSE >= 1):
Expand All @@ -60,17 +63,13 @@ def log_l2(*args):

"""Commands will be called in this script.
"""
FIND_MODEL_PATHS_CMD = ['find', '.', '-type', 'f', '-name', '*.tar.gz']
# git lfs pull --include="${onnx_model}" --exclude=""
PULL_CMD = ['git', 'lfs', 'pull', '--exclude=\"\"']
# git lfs pointer --file = "${onnx_model}" > ${onnx_model}.pt
CLEAN_CMD = ['git', 'lfs', 'pointer']
# git checkout file_path
CHECKOUT_CMD = ['git', 'checkout']
# tar -xzvf file.tar.gz
# `-mindepth 2` is to ignore the current folder but subfolders.
FIND_MODEL_PATHS_CMD = [
'find', '.', '-mindepth', '2', '-type', 'f', '-name', '*.tar.gz'
]
UNTAR_CMD = ['tar', '-xzvf']
WGET_CMD = ['wget', '--no-check-certificate', '--timestamping']
RM_CMD = ['rm']
MV_CMD = ['mv']
# Compile, run and verify an onnx model.
RUN_ONNX_MODEL = ['python', 'RunONNXModel.py']

Expand All @@ -81,10 +80,11 @@ def execute_commands(cmds):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout, stderr = out.communicate()
if stderr:
return (False, stderr.decode("utf-8"))
else:
return (True, stdout.decode("utf-8"))
if out.returncode == -signal.SIGSEGV:
return (False, "Segfault")
if out.returncode != 0:
return (False, stderr.decode("utf-8") + stdout.decode("utf-8"))
return (True, stdout.decode("utf-8"))


def execute_commands_to_file(cmds, ofile):
Expand Down Expand Up @@ -112,6 +112,36 @@ def execute_commands_to_file(cmds, ofile):
"emotion-ferplus-2",
}

int8_models = {
"bertsquad-12-int8",
"inception-v1-12-int8",
"googlenet-12-int8",
"zfnet512-12-int8",
"caffenet-12-int8",
"mobilenetv2-12-int8",
"squeezenet1.0-12-int8",
"densenet-12-int8",
"resnet50-v1-12-int8",
"efficientnet-lite4-11-int8",
"bvlcalexnet-12-int8",
"vgg16-12-int8",
"shufflenet-v2-12-int8",
"yolov3-12-int8",
"FasterRCNN-12-int8",
"fcn-resnet50-12-int8",
"ssd-12-int8",
"ssd_mobilenet_v1_12-int8",
"MaskRCNN-12-int8",
}

excluded_models = deprecated_models.union(int8_models)

# Additional information passed to RunONNXModel.py.
RunONNXModel_additional_options = {
"t5-decoder-with-lm-head-12": ['--shape_info=0:1x2,1:1x2x768'],
"t5-encoder-12": ['--shape_info=0:1x2,1:1x2x768']
}

# States
NO_TEST = 0
TEST_FAILED = 1
Expand All @@ -126,18 +156,16 @@ def obtain_all_model_paths():
model_names = [
path.split('/')[-1][:-len(".tag.gz")] for path in model_paths
] # remove .tag.gz
deprecated_names = set(model_names).intersection(deprecated_models)
excluded_names = set(model_names).intersection(excluded_models)

log_l1('\n')
deprecated_msg = ""
if (len(deprecated_names) != 0):
deprecated_msg = "where " + \
str(len(deprecated_names)) + \
" models are deprecated (using very old opsets, e.g. <= 3)"
log_l1("# There are {} models in the ONNX model zoo {}".format(
len(model_paths), deprecated_msg))
log_l1("See https://github.com/onnx/models/pull/389",
"for a list of deprecated models\n")
excluded_msg = ""
if (len(excluded_names) != 0):
excluded_msg = "where " + \
str(len(excluded_names)) + \
" models are not checked because of old opsets or quantization"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of line continuation, it's probably better to use parentheses for multi-line expressions:

excluded_msg = ("where " +
    str(len(excluded_names)) +
    " models are not checked because of old opsets or quantization")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reminding this! It's better.

print("There are {} models in the ONNX model zoo {}.".format(
len(model_paths), excluded_msg))
return model_names, model_paths


Expand Down Expand Up @@ -179,7 +207,9 @@ def check_model(model_path, model_name, compile_args):
has_data_sets = True
data_set = pb_files.split('\n')[0]
if (not has_data_sets):
log_l1("Warning: This model does not have test data sets.")
log_l1(
"Warning: model {} does not have test data sets. Will check the model with random data."
.format(model_name))

# compile, run and verify.
log_l1("Checking the model {} ...".format(model_name))
Expand All @@ -189,39 +219,35 @@ def check_model(model_path, model_name, compile_args):
if has_data_sets:
options += ['--verify=ref']
options += ['--data_folder={}'.format(data_set)]
if model_name in RunONNXModel_additional_options:
options += RunONNXModel_additional_options[model_name]
ok, msg = execute_commands(RUN_ONNX_MODEL + [onnx_file] + options)
state = TEST_PASSED if ok else TEST_FAILED
log_l1(msg)
log_l1("[{}] {}".format(model_name, msg))
return state


def pull_and_check_model(model_path, compile_args, pull_models, keep_model):
def pull_and_check_model(model_path, compile_args, keep_model):
state = NO_TEST

# Ignore deprecated models.
model_tag_gz = "./" + model_path.split('/')[-1]
model_name = model_path.split('/')[-1][:-len(".tag.gz")] # remove .tag.gz
if model_name in deprecated_models:
log_l1("The model {} is deprecated. Ignored.".format(model_name))
if model_name in excluded_models:
log_l1("The model {} is excluded. Ignored.".format(model_name))
return state, model_name

# pull the model.
if pull_models:
log_l1('Downloading {}'.format(model_path))
pull_cmd = PULL_CMD + ['--include={}'.format(model_path)]
ok, _ = execute_commands(pull_cmd)
if not ok:
log_l1("Failed to pull the model {}. Ignored.".format(model_name))
model_url = ONNX_MODEL_ZOO_URL + '/' + model_path
log_l1('Downloading {}'.format(model_url))
ok, _ = execute_commands(WGET_CMD + [model_url])

# check the model.
state = check_model(model_path, model_name, compile_args)
state = check_model(model_tag_gz, model_name, compile_args)

if pull_models and (not keep_model):
if not keep_model:
# remove the model to save the storage space.
clean_cmd = CLEAN_CMD + ['--file={}'.format(model_path)]
execute_commands_to_file(clean_cmd, '{}.pt'.format(model_path))
execute_commands(RM_CMD + [model_path])
execute_commands(MV_CMD + ['{}.pt'.format(model_path), model_path])
execute_commands(CHECKOUT_CMD + [model_path])
execute_commands(RM_CMD + [model_tag_gz])

return state, model_name

Expand Down Expand Up @@ -250,21 +276,12 @@ def main():
parser.add_argument(
'-compile_args',
help="Options passing to onnx-mlir to compile a model.")
parallel_group = parser.add_mutually_exclusive_group()
parallel_group.add_argument(
'-njobs',
type=int,
default=1,
help="The number of processes in parallel."
" The large -njobs is, the more disk space is needed"
" for downloaded onnx models. Default 1.")
parallel_group.add_argument(
'-pull_models',
action='store_true',
help="Pull models from the remote git repository."
" This requires git-lfs. Please follow the instruction here to install"
" git-lfs: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage."
)
parser.add_argument('-njobs',
type=int,
default=1,
help="The number of processes in parallel."
" The large -njobs is, the more disk space is needed"
" for downloaded onnx models. Default 1.")
args = parser.parse_args()

# Collect all model paths in the model zoo
Expand All @@ -284,7 +301,7 @@ def main():
if (args.m):
models_to_run = [args.m]

target_model_paths = []
target_model_paths = set()
for name in models_to_run:
if name not in all_model_names:
print(
Expand All @@ -293,26 +310,27 @@ def main():
difflib.get_close_matches(name, all_model_names,
len(all_model_names)))
return
target_model_paths += [m for m in all_model_paths if name in m]
for m in all_model_paths:
if name in m:
target_model_paths.add(m)

# Start processing the models.
results = Parallel(n_jobs=args.njobs,
verbose=1)(delayed(pull_and_check_model)(
path, args.compile_args, args.pull_models, args.k)
for path in target_model_paths)
results = Parallel(n_jobs=args.njobs, verbose=1)(
delayed(pull_and_check_model)(path, args.compile_args, args.k)
for path in target_model_paths)

# Report the results.
tested_models = [r[1] for r in results if r[0] != NO_TEST]
tested_models = {r[1] for r in results if r[0] != NO_TEST}
print("{} models tested: {}\n".format(len(tested_models),
', '.join(tested_models)))
passed_models = [r[1] for r in results if r[0] == TEST_PASSED]
passed_models = {r[1] for r in results if r[0] == TEST_PASSED}
print("{} models passed: {}\n".format(len(passed_models),
', '.join(passed_models)))
if len(passed_models) != len(tested_models):
failed_models = [r[1] for r in results if r[0] == TEST_FAILED]
msg = "{} model failed: {}\n".format(len(failed_models),
', '.join(failed_models))
if args.assertion:
failed_models = {r[1] for r in results if r[0] == TEST_FAILED}
msg = "{} models failed: {}\n".format(len(failed_models),
', '.join(failed_models))
if args.a:
raise AssertionError(msg)
else:
print(msg)
Expand Down
6 changes: 3 additions & 3 deletions utils/RunONNXModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ def generate_random_input(model, input_shapes):
"of the {} input is unknown.".format(ordinal(i + 1)),
"Use --shape_info to set.")
print(shape_proto)
exit()
exit(1)
else:
print("The shape of the {} input".format(ordinal(i + 1)),
"is unknown. Use --shape_info to set.")
print(shape_proto)
exit()
exit(1)
rinput = np.random.uniform(-1.0, 1.0,
explicit_shape).astype(np.float32)
print(" - {} input's shape {}".format(ordinal(i + 1), rinput.shape))
Expand Down Expand Up @@ -371,7 +371,7 @@ def main():
ref_outs = read_output_from_refs(model, args.data_folder)
else:
print("Invalid verify option")
exit()
exit(1)

# For each output tensor, compare results.
for i, name in enumerate(output_names):
Expand Down