Commit

improved code
mht-sharma committed Dec 27, 2023
1 parent 0baefc4 commit 4ce9261
Showing 6 changed files with 97 additions and 83 deletions.
24 changes: 3 additions & 21 deletions optimum/exporters/tasks.py
@@ -1552,26 +1552,6 @@ def infer_task_from_model(

return task

@classmethod
def get_config_from_model(
cls,
model_name_or_path: Union[str, Path],
subfolder: str = "",
revision: Optional[str] = None,
):
full_model_path = Path(model_name_or_path) / subfolder

config_path = full_model_path / CONFIG_NAME

if not full_model_path.is_dir():
config_path = huggingface_hub.hf_hub_download(
model_name_or_path, CONFIG_NAME, subfolder=subfolder, revision=revision
)

model_config = PretrainedConfig.from_json_file(config_path)

return model_config

@classmethod
def infer_library_from_model(
cls,
@@ -1619,7 +1599,9 @@ def infer_library_from_model(
if "model_index.json" in all_files:
library_name = "diffusers"
elif CONFIG_NAME in all_files:
model_config = TasksManager.get_config_from_model(model_name_or_path, subfolder, revision)
model_config = PretrainedConfig.from_pretrained(
model_name_or_path, subfolder=subfolder, revision=revision
)

if hasattr(model_config, "pretrained_cfg") or hasattr(model_config, "architecture"):
library_name = "timm"
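
The hunk above drops the get_config_from_model helper in favor of PretrainedConfig.from_pretrained. As a hedged, standalone illustration of the detection logic it feeds (not code from this commit), the sketch below loads a config the same way and applies the timm check; the checkpoint id is borrowed from the test list later in this diff and is only an example.

from transformers import PretrainedConfig

# Sketch only: PretrainedConfig.from_pretrained reads config.json from a local folder or the
# Hub without AutoConfig's model_type dispatch, so it also handles timm-style configs.
config = PretrainedConfig.from_pretrained("timm/resnet18.fb_swsl_ig1b_ft_in1k")

# infer_library_from_model flags the repo as timm when these timm-specific fields are present.
is_timm = hasattr(config, "pretrained_cfg") or hasattr(config, "architecture")
print("timm" if is_timm else "transformers")
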
30 changes: 12 additions & 18 deletions optimum/modeling_base.py
@@ -22,14 +22,13 @@
from typing import TYPE_CHECKING, Optional, Union

from huggingface_hub import HfApi, HfFolder
from transformers import AutoConfig, add_start_docstrings
from transformers import AutoConfig, PretrainedConfig, add_start_docstrings

from .exporters import TasksManager
from .utils import CONFIG_NAME


if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers import PreTrainedModel, TFPreTrainedModel


logger = logging.getLogger(__name__)
@@ -81,7 +80,7 @@ class OptimizedModel(PreTrainedModel):
base_model_prefix = "optimized_model"
config_name = CONFIG_NAME

def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "PretrainedConfig"):
def __init__(self, model: Union["PreTrainedModel", "TFPreTrainedModel"], config: PretrainedConfig):
super().__init__()
self.model = model
self.config = config
@@ -225,9 +224,9 @@ def _load_config(
force_download: bool = False,
subfolder: str = "",
trust_remote_code: bool = False,
) -> "PretrainedConfig":
) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path=config_name_or_path,
revision=revision,
cache_dir=cache_dir,
@@ -239,7 +238,7 @@ def _load_config(
except OSError as e:
# if config not found in subfolder, search for it at the top level
if subfolder != "":
config = AutoConfig.from_pretrained(
config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path=config_name_or_path,
revision=revision,
cache_dir=cache_dir,
@@ -258,7 +257,7 @@ def _load_config(
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
@@ -274,7 +273,7 @@ def _from_pretrained(
def _from_transformers(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
@@ -294,7 +293,7 @@ def _from_transformers(
def _export(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
@@ -319,7 +318,7 @@ def from_pretrained(
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
subfolder: str = "",
config: Optional["PretrainedConfig"] = None,
config: Optional[PretrainedConfig] = None,
local_files_only: bool = False,
trust_remote_code: bool = False,
revision: Optional[str] = None,
@@ -346,19 +345,14 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir)

if library_name == "timm":
config = TasksManager.get_config_from_model(model_id, subfolder, revision)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
config = AutoConfig.from_pretrained(
config = PretrainedConfig.from_pretrained(
os.path.join(model_id, subfolder, CONFIG_NAME), trust_remote_code=trust_remote_code
)
elif CONFIG_NAME in os.listdir(model_id):
config = AutoConfig.from_pretrained(
config = PretrainedConfig.from_pretrained(
os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
)
logger.info(
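
For context on the _load_config change above, here is a hedged, standalone sketch of the subfolder-then-root fallback now built on PretrainedConfig.from_pretrained; the helper name and its exact behaviour are illustrative and not part of optimum.

from typing import Optional

from transformers import PretrainedConfig


def load_config_with_fallback(
    model_id: str, subfolder: str = "", revision: Optional[str] = None
) -> PretrainedConfig:
    # Illustrative only: mirror the pattern above, where a config that is missing from the
    # requested subfolder is searched for at the top level of the repository instead.
    try:
        return PretrainedConfig.from_pretrained(model_id, subfolder=subfolder, revision=revision)
    except OSError:
        if subfolder != "":
            return PretrainedConfig.from_pretrained(model_id, revision=revision)
        raise
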
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_ort.py
@@ -540,7 +540,7 @@ def _from_transformers(
use_io_binding: Optional[bool] = None,
task: Optional[str] = None,
) -> "ORTModel":
return cls._from_export(
return cls._export(
model_id=model_id,
config=config,
revision=revision,
6 changes: 0 additions & 6 deletions tests/exporters/exporters_utils.py
@@ -241,12 +241,6 @@
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"regnet": "facebook/regnet-y-040",
"resnet": "microsoft/resnet-50",
"resnext26ts": "timm/resnext26ts.ra2_in1k",
"resnext50-32x4d": "timm/resnext50_32x4d.tv2_in1k",
"resnext50d-32x4d": "timm/resnext50d_32x4d.bt_in1k",
"resnext101-32x4d": "timm/resnext101_32x4d.gluon_in1k",
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
"roberta": "roberta-base",
"roformer": "junnyu/roformer_chinese_base",
"sam": "facebook/sam-vit-base",
75 changes: 44 additions & 31 deletions tests/onnxruntime/test_modeling.py
@@ -2715,19 +2715,27 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
"vit",
]

TIMM_SUPPORTED_ARCHITECTURES = [
"resnext26ts",
"resnext50-32x4d",
"resnext50d-32x4d",
"resnext101-32x4d",
"resnext101-32x8d",
"resnext101-64x4d",
]
TIMM_SUPPORTED_ARCHITECTURES = ["default-timm-config"]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
ORTMODEL_CLASS = ORTModelForImageClassification
TASK = "image-classification"

def _get_model_ids(self, model_arch):
model_ids = MODEL_NAMES[model_arch]
if isinstance(model_ids, dict):
model_ids = list(model_ids.keys())
else:
model_ids = [model_ids]
return model_ids

def _get_onnx_model_dir(self, model_id, model_arch, test_name):
onnx_model_dir = self.onnx_model_dirs[test_name]
if isinstance(MODEL_NAMES[model_arch], dict):
onnx_model_dir = onnx_model_dir[model_id]

return onnx_model_dir

def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = ORTModelForImageClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
@@ -2743,37 +2751,42 @@ def test_compare_to_timm(self, model_arch):

self._setup(model_args)

model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageClassification.from_pretrained(self.onnx_model_dirs[model_arch])
model_ids = self._get_model_ids(model_arch)
for model_id in model_ids:
onnx_model = ORTModelForImageClassification.from_pretrained(
self._get_onnx_model_dir(model_id, model_arch, model_arch)
)

self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession)
self.assertIsInstance(onnx_model.config, PretrainedConfig)
self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession)
self.assertIsInstance(onnx_model.config, PretrainedConfig)

set_seed(SEED)
timm_model = timm.create_model(model_id, pretrained=True)
timm_model = timm_model.eval()
set_seed(SEED)
timm_model = timm.create_model(model_id, pretrained=True)
timm_model = timm_model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(timm_model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(timm_model)
transforms = timm.data.create_transform(**data_config, is_training=False)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
image = Image.open(requests.get(url, stream=True).raw)
inputs = transforms(image).unsqueeze(0)
url = (
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)
image = Image.open(requests.get(url, stream=True).raw)
inputs = transforms(image).unsqueeze(0)

with torch.no_grad():
timm_outputs = timm_model(inputs)
with torch.no_grad():
timm_outputs = timm_model(inputs)

for input_type in ["pt", "np"]:
if input_type == "np":
inputs = inputs.cpu().detach().numpy()
onnx_outputs = onnx_model(inputs)
for input_type in ["pt", "np"]:
if input_type == "np":
inputs = inputs.cpu().detach().numpy()
onnx_outputs = onnx_model(inputs)

self.assertIn("logits", onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
self.assertIn("logits", onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])

# compare tensor outputs
self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), timm_outputs, atol=1e-4))
# compare tensor outputs
self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), timm_outputs, atol=1e-4))

gc.collect()

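
Outside the test class, the comparison exercised by test_compare_to_timm can be reproduced roughly as below. This is a hedged sketch, not part of the test suite: export=True assumes the on-the-fly ONNX export of the chosen timm checkpoint succeeds, and the checkpoint id is one example entry from the default-timm-config mapping added in the next file.

import requests
import timm
import torch
from PIL import Image

from optimum.onnxruntime import ORTModelForImageClassification

model_id = "timm/mobilenetv2_100.ra_in1k"  # example entry; any id from the mapping should behave the same

# Reference timm model plus its model-specific preprocessing (resize, normalization).
timm_model = timm.create_model(model_id, pretrained=True).eval()
data_config = timm.data.resolve_model_data_config(timm_model)
transforms = timm.data.create_transform(**data_config, is_training=False)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
inputs = transforms(Image.open(requests.get(url, stream=True).raw)).unsqueeze(0)

# Export the same checkpoint to ONNX on the fly and compare logits with the timm reference.
onnx_model = ORTModelForImageClassification.from_pretrained(model_id, export=True)
with torch.no_grad():
    timm_logits = timm_model(inputs)
onnx_logits = torch.Tensor(onnx_model(inputs).logits)
print(torch.allclose(onnx_logits, timm_logits, atol=1e-4))
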
43 changes: 37 additions & 6 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -47,6 +47,43 @@
"data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"default-timm-config": {
"timm/inception_v3.tf_adv_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.in1k": ["image-classification"],
"timm/resnetv2_50x1_bit.goog_distilled_in1k": ["image-classification"],
"timm/cspdarknet53.ra_in1k": ["image-classification"],
"timm/cspresnet50.ra_in1k": ["image-classification"],
"timm/cspresnext50.ra_in1k": ["image-classification"],
"timm/densenet121.ra_in1k": ["image-classification"],
"timm/dla102.in1k": ["image-classification"],
"timm/dpn107.mx_in1k": ["image-classification"],
"timm/ecaresnet101d.miil_in1k": ["image-classification"],
"timm/efficientnet_b1_pruned.in1k": ["image-classification"],
"timm/inception_resnet_v2.tf_ens_adv_in1k": ["image-classification"],
"timm/fbnetc_100.rmsp_in1k": ["image-classification"],
"timm/xception41.tf_in1k": ["image-classification"],
"timm/senet154.gluon_in1k": ["image-classification"],
"timm/seresnext26d_32x4d.bt_in1k": ["image-classification"],
"timm/hrnet_w18.ms_aug_in1k": ["image-classification"],
"timm/inception_v3.gluon_in1k": ["image-classification"],
"timm/inception_v4.tf_in1k": ["image-classification"],
"timm/mixnet_s.ft_in1k": ["image-classification"],
"timm/mnasnet_100.rmsp_in1k": ["image-classification"],
"timm/mobilenetv2_100.ra_in1k": ["image-classification"],
"timm/mobilenetv3_small_050.lamb_in1k": ["image-classification"],
"timm/nasnetalarge.tf_in1k": ["image-classification"],
"timm/tf_efficientnet_b0.ns_jft_in1k": ["image-classification"],
"timm/pnasnet5large.tf_in1k": ["image-classification"],
"timm/regnetx_002.pycls_in1k": ["image-classification"],
"timm/regnety_002.pycls_in1k": ["image-classification"],
"timm/res2net101_26w_4s.in1k": ["image-classification"],
"timm/res2next50.in1k": ["image-classification"],
"timm/resnest101e.in1k": ["image-classification"],
"timm/spnasnet_100.rmsp_in1k": ["image-classification"],
"timm/resnet18.fb_swsl_ig1b_ft_in1k": ["image-classification"],
"timm/wide_resnet101_2.tv_in1k": ["image-classification"],
"timm/tresnet_l.miil_in1k": ["image-classification"],
},
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
"detr": "hf-internal-testing/tiny-random-detr",
@@ -91,12 +128,6 @@
"pix2struct": "fxmarty/pix2struct-tiny-random",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"resnet": "hf-internal-testing/tiny-random-resnet",
"resnext26ts": "timm/resnext26ts.ra2_in1k",
"resnext50-32x4d": "timm/resnext50_32x4d.tv2_in1k",
"resnext50d-32x4d": "timm/resnext50d_32x4d.bt_in1k",
"resnext101-32x4d": "timm/resnext101_32x4d.gluon_in1k",
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
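
The default-timm-config entry makes MODEL_NAMES values heterogeneous: most architectures map to a single checkpoint string, while this one maps to a dict of {model_id: [tasks]}. Below is a small hedged sketch of how such entries can be normalized, mirroring the _get_model_ids helper added earlier in this commit; the two sample entries are copied from this file.

# Standalone illustration, not test code: normalize MODEL_NAMES entries that are either a
# plain checkpoint string or a {model_id: [tasks]} mapping, as with "default-timm-config".
MODEL_NAMES = {
    "resnet": "hf-internal-testing/tiny-random-resnet",
    "default-timm-config": {"timm/mobilenetv2_100.ra_in1k": ["image-classification"]},
}


def get_model_ids(model_arch: str) -> list:
    entry = MODEL_NAMES[model_arch]
    return list(entry.keys()) if isinstance(entry, dict) else [entry]


print(get_model_ids("resnet"))               # ['hf-internal-testing/tiny-random-resnet']
print(get_model_ids("default-timm-config"))  # ['timm/mobilenetv2_100.ra_in1k']
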
