Skip to content

Commit

Permalink
enable additional flags for tank test models (huggingface#866)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex <alexander@nod-labs.com>
  • Loading branch information
aldesilv and Alex authored Feb 2, 2023
1 parent 5c7deb3 commit b3fc0f2
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
4 changes: 4 additions & 0 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def get_model_specific_args():
ms_args = []
if shark_args.enable_conv_transform == True:
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
if shark_args.enable_img2col_transform == True:
ms_args += ["--iree-flow-enable-conv-img2col-transform"]
if shark_args.use_winograd == True:
ms_args += ["--iree-flow-enable-conv-winograd-transform"]
return ms_args


Expand Down
14 changes: 14 additions & 0 deletions shark/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,18 @@ def dir_file(path):
help="Enables the --iree-flow-enable-conv-nchw-to-nhwc-transform flag.",
)

parser.add_argument(
"--enable_img2col_transform",
default=False,
action="store_true",
help="Enables the --iree-flow-enable-conv-img2col-transform flag.",
)

parser.add_argument(
"--use_winograd",
default=False,
action="store_true",
help="Enables the --iree-flow-enable-conv-winograd-transform flag.",
)

shark_args, unknown = parser.parse_known_args()
8 changes: 4 additions & 4 deletions tank/all_models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nh
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390",""
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
microsoft/resnet-50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344",""
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/388","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/343","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc,True,False,True,"",""
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,False,False,False,"","macos"
efficientnet-v2-s,mhlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
4 changes: 4 additions & 0 deletions tank/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def create_and_check_module(self, dynamic, device):
shark_args.enable_conv_transform = True
else:
shark_args.enable_conv_transform = False
if "img2col" in self.config["flags"]:
shark_args.enable_img2col_transform = True
if "winograd" in self.config["flags"]:
shark_args.use_winograd = True

model, func_name, inputs, golden_out = download_model(
self.config["model_name"],
Expand Down

0 comments on commit b3fc0f2

Please sign in to comment.