Skip to content

Commit

Permalink
support task map attribute for yolo >= 8.0.44
Browse files Browse the repository at this point in the history
  • Loading branch information
wadhah101 committed Feb 3, 2024
1 parent 0350758 commit f297381
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
6 changes: 6 additions & 0 deletions tests/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
hub_id = "ultralyticsplus/yolov8s"


# for ultralytics < 8.0.44
def test_load_from_hub():
path = download_from_hub(hub_id)


# for ultralytics >= 8.0.44
def test_load_from_hub_yolo_8_0_44():
model = YOLO("keremberke/yolov8n-table-extraction")


def test_yolo_from_hub():
model = YOLO(hub_id)

Expand Down
28 changes: 22 additions & 6 deletions ultralyticsplus/ultralytics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,28 @@ def _load_from_hf_hub(self, weights: str, hf_token=None):
self.task = self.model.args["task"]
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self._assign_ops_from_task()

# for loading model with ultralytics <8.0.44
if hasattr(self, "_assign_ops_from_task"):
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self.task_map[self.task]

# for loading model with ultralytics >=8.0.44
else:
if self.task not in self.task_map:
raise ValueError(
f"Task '{self.task}' not supported. Supported tasks: {list(self.task_map.keys())}"
)
(
self.ModelClass,
self.TrainerClass,
self.ValidatorClass,
self.PredictorClass,
) = self.task_map[self.task]


def render_result(
Expand Down

0 comments on commit f297381

Please sign in to comment.