Make the onnx importer more robust for internal/external and large models (#2794)

Fix for #2765

The ONNX docs say that shape inference cannot be done with the in-memory
API for models larger than 2 GB. This fix tries the in-memory API first and
falls back to the file-based API when that fails. Since the file-based API
generates an intermediate file, this change also adds a --keep-temps switch
to keep that file, which is deleted by default.
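
For reference, a minimal sketch (not part of this commit) of the two onnx
shape-inference APIs involved; "model.onnx" and "inferred.onnx" are
placeholder paths:

import onnx

model = onnx.load("model.onnx")
try:
    # In-memory API: convenient, but fails for models over the 2 GB
    # protobuf limit.
    inferred = onnx.shape_inference.infer_shapes(model)
except ValueError:
    # File-based API: works on paths, so the 2 GB limit does not apply.
    onnx.shape_inference.infer_shapes_path("model.onnx", "inferred.onnx")
    inferred = onnx.load("inferred.onnx")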

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
daveliddell and Dave Liddell authored Feb 1, 2024
1 parent 34f6948 commit 04be6ba
Showing 2 changed files with 244 additions and 6 deletions.
python/torch_mlir/tools/import_onnx/__main__.py (100 additions, 6 deletions)
@@ -14,7 +14,9 @@
 python -m torch_mlir.tools.import_onnx ...
 """
 import argparse
+import os
 from pathlib import Path
+import shutil
 import sys

 import onnx
@@ -27,8 +29,8 @@
 )


-def main(args):
-    model_proto = load_onnx_model(args.input_file)
+def main(args: argparse.Namespace):
+    model_proto = load_onnx_model(args)
     context = Context()
     torch_d.register_dialect(context)
     model_info = onnx_importer.ModelInfo(model_proto)
@@ -48,13 +50,84 @@ def main(args):
     print(m.get_asm(assume_verified=not args.no_verify))


-def load_onnx_model(file_path: Path) -> onnx.ModelProto:
-    raw_model = onnx.load(file_path)
-    inferred_model = onnx.shape_inference.infer_shapes(raw_model)
+def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
+    # Do shape inference two ways. First, attempt in-memory to avoid redundant
+    # loading and the need for writing a temporary file somewhere. If that
+    # fails, typically because of the 2 GB protobuf size limit, try again via
+    # files. See
+    # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#shape-inference-a-large-onnx-model-2gb
+    # for details about the file-based technique.
+
+    # Make a temp dir for all the temp files we'll be generating as a side
+    # effect of inferring shapes. For now, the only file is a new .onnx holding
+    # the revised model with shapes.
+    #
+    # TODO: If the program temp_dir is None, we should be using an ephemeral
+    # temp directory instead of a hard-coded path in order to avoid data races
+    # by default.
+    input_dir = os.path.dirname(os.path.abspath(args.input_file))
+    temp_dir = (
+        Path(input_dir if args.temp_dir is None else args.temp_dir)
+        / "onnx-importer-temp"
+    )
+    shutil.rmtree(temp_dir, ignore_errors=True)
+    temp_dir.mkdir(exist_ok=True)
+
+    # Load the model, with possible external data coming from the default
+    # location, or the location specified on the command line.
+    if args.data_dir is None:
+        raw_model = onnx.load(args.input_file)
+    else:
+        raw_model = onnx.load(args.input_file, load_external_data=False)
+        onnx.load_external_data_for_model(raw_model, args.data_dir)
+
+    # Run the checker to test whether the file is above the threshold for
+    # in-memory shape inference. If not, go ahead and do the shape inference.
+    try:
+        onnx.checker.check_model(raw_model)
+        inferred_model = onnx.shape_inference.infer_shapes(raw_model)
+        return inferred_model
+    except ValueError:
+        pass
+
+    # The following code was an attempt to work around the bug where models
+    # with external data produce invalid output shapes after infer_shapes_path.
+    # It works with small models but threw an error for llama, seeming to
+    # indicate that the protobuf is corrupt.
+    #
+    # temp_raw_file = temp_dir / "raw.onnx"
+    # onnx.save(raw_model, temp_raw_file, save_as_external_data=False)
+    # onnx.shape_inference.infer_shapes_path(temp_raw_file, temp_inferred_file)
+    # inferred_model = onnx.load(temp_inferred_file)
+
+    # Model is too big for in-memory inference: do file-based shape inference
+    # to a temp file.
+    temp_inferred_file = temp_dir / "inferred.onnx"
+    onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file)
+
+    # Sanity check the shape-inferred model to be sure we have a good model
+    # for the importer. This call uses the file-based method, as the
+    # in-memory method (passing the loaded model) fails due to the 2 GB limit.
+    #
+    # TODO: this call throws an exception because it can't find the external
+    # data files, and there doesn't appear to be a way to let the checker know
+    # where to find them.
+    #
+    # onnx.checker.check_model(temp_inferred_file)
+
+    # Load the temp file and the external data.
+    inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
+    data_dir = Path(input_dir if args.temp_dir is None else args.data_dir)
+    onnx.load_external_data_for_model(inferred_model, data_dir)
+
+    # Remove the inferred shape file unless asked to keep it.
+    if not args.keep_temps:
+        shutil.rmtree(temp_dir)
+
     return inferred_model


-def parse_arguments(argv=None):
+def parse_arguments(argv=None) -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
     parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
     parser.add_argument(
@@ -65,6 +138,27 @@ def parse_arguments(argv=None):
         action="store_true",
         help="Disable verification prior to printing",
     )
+    parser.add_argument(
+        "--keep-temps", action="store_true", help="Keep intermediate files"
+    )
+    parser.add_argument(
+        "--temp-dir",
+        help="Pre-existing directory in which to create temporary files."
+        ' For example, to place temporaries under the directory "foo/bar",'
+        ' specify --temp-dir=foo/bar. "foo/bar" must already exist.'
+        " Defaults to the directory of the input file.",
+        type=Path,
+    )
+    parser.add_argument(
+        "--data-dir",
+        help="Path between CWD and the base directory of the data,"
+        " excluding the directories given in the 'location' argument of"
+        " convert_model_to_external_data. For example, if 'location' was"
+        ' "data/data.bin" and the relative path from CWD to that .bin file is'
+        ' a/b/data/data.bin, then set data-dir to "a/b".'
+        " Defaults to the directory of the input file.",
+        type=Path,
+    )
     args = parser.parse_args(argv)
     return args
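
For reference, a hypothetical invocation exercising the new switches
("model.onnx" and "scratch" are placeholders, and per the help text the
--temp-dir directory must already exist):

python -m torch_mlir.tools.import_onnx model.onnx -o model.torch.mlir \
    --temp-dir scratch --data-dir . --keep-temps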

test/python/onnx_importer/command_line_test.py (144 additions, 0 deletions)
@@ -0,0 +1,144 @@
+# Based on code Copyright (c) Advanced Micro Devices, Inc.
+#
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s --output %t
+
+from pathlib import Path
+
+import logging
+import shutil
+import sys
+import subprocess
+import unittest
+import unittest.mock
+
+import onnx
+
+from torch_mlir.tools.import_onnx import __main__
+
+# For ONNX models
+
+import numpy
+from onnx import numpy_helper, TensorProto
+from onnx.helper import (
+    make_model, make_node, make_graph,
+    make_tensor_value_info)
+from onnx.external_data_helper import convert_model_to_external_data
+from onnx.checker import check_model
+
+# Accept the output path on the command line or default to a sibling
+# to this file. We have to pop this off explicitly or else unittest
+# won't understand.
+if len(sys.argv) > 1 and sys.argv[1] == "--output":
+    OUTPUT_PATH = Path(sys.argv[2])
+    del sys.argv[1:3]
+else:
+    OUTPUT_PATH = Path(__file__).resolve().parent / "output"
+
+OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
+
+
+def const_model() -> onnx.ModelProto:
+    # Note: data_path must be relative to model_file
+
+    const = make_node(
+        'Constant', [], ['c_shape'], 'const',
+        value=numpy_helper.from_array(numpy.array([4], dtype=numpy.int64)))
+    cofshape = make_node(
+        'ConstantOfShape', ['c_shape'], ['c_out'], 'cofshape',
+        value=numpy_helper.from_array(numpy.array([1], dtype=numpy.int64)))
+
+    outval = make_tensor_value_info('c_out', TensorProto.INT64, [None])
+    graph = make_graph([const, cofshape], 'constgraph', [], [outval])
+
+    onnx_model = make_model(graph)
+    check_model(onnx_model)
+    return onnx_model
+
+
+def linear_model() -> onnx.ModelProto:
+    # initializers
+    k_dim = 32
+    value = numpy.arange(k_dim).reshape([k_dim, 1])
+    value = numpy.asarray(value, dtype=numpy.float32)
+    A = numpy_helper.from_array(value, name='A')
+
+    value = numpy.array([0.4], dtype=numpy.float32).reshape([1, 1])
+    C = numpy_helper.from_array(value, name='C')
+
+    # the part which does not change
+    X = make_tensor_value_info('X', TensorProto.FLOAT, [1, k_dim])
+    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None])
+    node1 = make_node('MatMul', ['X', 'A'], ['AX'])
+    node2 = make_node('Add', ['AX', 'C'], ['Y'])
+    graph = make_graph([node1, node2], 'lr', [X], [Y], [A, C])
+    onnx_model = make_model(graph)
+    check_model(onnx_model)
+    return onnx_model
+
+
+ALL_MODELS = [
+    const_model,
+    linear_model
+]
+
+
+class CommandLineTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir = OUTPUT_PATH / "command-line"
+        shutil.rmtree(cls.test_dir, ignore_errors=True)
+        cls.test_dir.mkdir(parents=True, exist_ok=True)
+
+    def get_run_path(self, model_name: str) -> Path:
+        run_path = CommandLineTest.test_dir / model_name
+        run_path.mkdir(exist_ok=True)
+        return run_path
+
+    def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str):
+        run_path = self.get_run_path(model_name)
+        model_file = run_path / f"{model_name}-i.onnx"
+        mlir_file = run_path / f"{model_name}-i.torch.mlir"
+        onnx.save(onnx_model, model_file)
+        args = __main__.parse_arguments([
+            str(model_file), "-o", str(mlir_file)])
+        __main__.main(args)
+
+    def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
+        run_path = self.get_run_path(model_name)
+        model_file = run_path / f"{model_name}-e.onnx"
+        mlir_file = run_path / f"{model_name}-e.torch.mlir"
+        data_dir_name = f"{model_name}-data"
+        model_data_dir = run_path / data_dir_name
+        model_data_dir.mkdir(exist_ok=True)
+        convert_model_to_external_data(
+            onnx_model, all_tensors_to_one_file=True,
+            location=data_dir_name + "/data.bin",
+            size_threshold=48,
+            convert_attribute=True)
+        onnx.save(onnx_model, model_file)
+        temp_dir = run_path / "temp"
+        temp_dir.mkdir(exist_ok=True)
+        args = __main__.parse_arguments([
+            str(model_file), "-o", str(mlir_file), "--keep-temps", "--temp-dir",
+            str(temp_dir), "--data-dir", str(run_path)])
+        __main__.main(args)
+
+    def test_all(self):
+        for model_func in ALL_MODELS:
+            model_name = model_func.__name__
+            model = model_func()
+            with self.subTest(f"model {model_name}", model_name=model_name):
+                with self.subTest("Internal data"):
+                    self.run_model_intern(model, model_name)
+                with self.subTest("External data"):
+                    self.run_model_extern(model, model_name)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
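
The test pops its --output argument off argv before unittest parses the rest,
so a direct run might look like this (the output directory is a placeholder):

python test/python/onnx_importer/command_line_test.py --output /tmp/onnx-import-test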
