Skip to content

Commit

Permalink
improve segmentation visualization (#17)
Browse files Browse the repository at this point in the history
* improve segmentation visualization

* Update __init__.py
  • Loading branch information
fcakyon authored Jan 17, 2023
1 parent ace59e3 commit 00b65f7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ultralyticsplus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .hf_utils import push_to_hfhub, download_from_hub
from .ultralytics_utils import YOLO, render_model_output, postprocess_classify_output

__version__ = "0.0.8"
__version__ = "0.0.9"
30 changes: 18 additions & 12 deletions ultralyticsplus/ultralytics_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
import os
from pathlib import Path
from ultralytics import YOLO as YOLOBase
from ultralytics.nn.tasks import attempt_load_one_weight
from sahi.prediction import ObjectPrediction, PredictionScore
from sahi.utils.cv import visualize_object_predictions

import numpy as np
from PIL import Image
from sahi.utils.cv import read_image_as_pil
from sahi.prediction import ObjectPrediction, PredictionScore
from sahi.utils.cv import (
get_bool_mask_from_coco_segmentation,
read_image_as_pil,
visualize_object_predictions,
)
from ultralytics import YOLO as YOLOBase
from ultralytics.nn.tasks import attempt_load_one_weight

from ultralyticsplus.hf_utils import download_from_hub

Expand Down Expand Up @@ -69,9 +72,7 @@ def _load_from_hf_hub(self, weights: str, hf_token=None):
) = self._guess_ops_from_task(self.task)


def render_model_output(
image, model: YOLO, model_output: dict
) -> Image.Image:
def render_model_output(image, model: YOLO, model_output: dict) -> Image.Image:
"""
Renders predictions on the image
Expand All @@ -84,7 +85,7 @@ def render_model_output(
Returns:
Image.Image: Image with predictions
"""
if model.overrides["task"] not in ['detect', 'segment']:
if model.overrides["task"] not in ["detect", "segment"]:
raise ValueError(
f"Model task must be either 'detect' or 'segment'. Got {model.overrides['task']}"
)
Expand All @@ -103,19 +104,24 @@ def render_model_output(
for *xyxy, conf, cls in det:
if segment:
segmentation = [segment[det_ind].ravel().tolist()]
bool_mask = get_bool_mask_from_coco_segmentation(
segmentation, width=np_image.shape[1], height=np_image.shape[0]
)
if sum(sum(bool_mask == 1)) == 0:
continue
object_prediction = ObjectPrediction.from_coco_segmentation(
segmentation=segmentation,
category_name=names[int(cls)],
category_id=int(cls),
full_shape=[np_image.shape[1], np_image.shape[0]]
full_shape=[np_image.shape[0], np_image.shape[1]],
)
object_prediction.score = PredictionScore(value=conf)
else:
object_prediction = ObjectPrediction(
bbox=xyxy,
category_name=names[int(cls)],
category_id=int(cls),
score=conf
score=conf,
)
object_predictions.append(object_prediction)
det_ind += 1
Expand All @@ -125,7 +131,7 @@ def render_model_output(
object_prediction_list=object_predictions,
)

return Image.fromarray(result['image'])
return Image.fromarray(result["image"])


def postprocess_classify_output(model: YOLO, prob: np.ndarray) -> dict:
Expand Down

0 comments on commit 00b65f7

Please sign in to comment.