From b3fc0f29ccaa2f8021ebe14ea25670a5ede99afd Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Thu, 2 Feb 2023 11:19:33 -0800 Subject: [PATCH] enable additional flags for tank test models (#866) Co-authored-by: Alex --- shark/iree_utils/compile_utils.py | 4 ++++ shark/parser.py | 14 ++++++++++++++ tank/all_models.csv | 8 ++++---- tank/test_models.py | 4 ++++ 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index d19259df9ec89..8251f9285e458 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -81,6 +81,10 @@ def get_model_specific_args(): ms_args = [] if shark_args.enable_conv_transform == True: ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"] + if shark_args.enable_img2col_transform == True: + ms_args += ["--iree-flow-enable-conv-img2col-transform"] + if shark_args.use_winograd == True: + ms_args += ["--iree-flow-enable-conv-winograd-transform"] return ms_args diff --git a/shark/parser.py b/shark/parser.py index ee3ed3ba27623..97f38ff2f33a9 100644 --- a/shark/parser.py +++ b/shark/parser.py @@ -112,4 +112,18 @@ def dir_file(path): help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.", ) +parser.add_argument( + "--enable_img2col_transform", + default=False, + action="store_true", + help="Enables the --iree-flow-enable-conv-img2col-transform flag.", +) + +parser.add_argument( + "--use_winograd", + default=False, + action="store_true", + help="Enables the --iree-flow-enable-conv-winograd-transform flag.", +) + shark_args, unknown = parser.parse_known_args() diff --git a/tank/all_models.csv b/tank/all_models.csv index 9659bbe569d52..15da7581651af 100644 --- a/tank/all_models.csv +++ b/tank/all_models.csv @@ -22,15 +22,15 @@ facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nh google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311","" microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","" microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"","" -microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" +microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos" google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","" mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos" nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos" -resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" +resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos" resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos" resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" -resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc,True,False,True,"","" +resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"","" squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" -wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" +wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos" efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos" mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos" diff --git a/tank/test_models.py b/tank/test_models.py index 2013cdff43aab..ced6d33e3f474 100644 --- a/tank/test_models.py +++ b/tank/test_models.py @@ -143,6 +143,10 @@ def create_and_check_module(self, dynamic, device): shark_args.enable_conv_transform = True else: shark_args.enable_conv_transform = False + if "img2col" in self.config["flags"]: + shark_args.enable_img2col_transform = True + if "winograd" in self.config["flags"]: + shark_args.use_winograd = True model, func_name, inputs, golden_out = download_model( self.config["model_name"],