Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepSort tracking enhancements #4372

Merged
merged 2 commits into from
May 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 109 additions & 81 deletions fiftyone/utils/tracking/deepsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,105 +5,143 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
# pylint: disable=no-member

import logging
import cv2

import eta.core.video as etav

import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.core.utils as fou
import fiftyone.core.validation as fov

dsrt = fou.lazy_import(
"deep_sort_realtime.deepsort_tracker",
callback=lambda: fou.ensure_package("deep-sort-realtime"),
)

dsrt = fou.lazy_import("deep_sort_realtime.deepsort_tracker")

logger = logging.getLogger(__name__)


class DeepSort:
class DeepSort(object):
@staticmethod
def track(
dataset,
sample_collection,
in_field,
out_field="frames.ds_tracks",
max_age=5,
keep_confidence=False,
skip_failures=True,
progress=None,
):
"""Performs object tracking using the DeepSort algorithm on a video dataset.
"""Performs object tracking using the DeepSort algorithm on the given
video samples.

DeepSort is an algorithm for tracking multiple objects in video streams
based on deep learning techniques. It associates bounding boxes between
frames and maintains tracks of objects over time.

Args:
dataset: a FiftyOne dataset
in_field: the name of the field containing detections in each frame
out_field ("frames.ds_tracks"): the name of the field to store tracking
information of the detections
max_age (5): the maximum number of consecutive missed detections before
a track is deleted.
sample_collection: a
:class:`fiftyone.core.collections.SampleCollection`
in_field: the name of a frame field containing
:class:`fiftyone.core.labels.Detections` to track. The
``"frames."`` prefix is optional
out_field ("frames.ds_tracks"): the name of a frame field to store
the output :class:`fiftyone.core.labels.Detections` with
tracking information. The ``"frames."`` prefix is optional
max_age (5): the maximum number of consecutive missed detections before
a track is deleted
keep_confidence (False): whether to store the detection confidence
of the tracked objects in the out_field
progress (None): whether to display a progress bar (True/False)
of the tracked objects in the ``out_field``
skip_failures (True): whether to gracefully continue without
raising an error if tracking fails for a video
progress (None): whether to render a progress bar (True/False),
use the default value ``fiftyone.config.show_progress_bars``
(None), or a progress callback function to invoke instead
"""
if not in_field.startswith("frames.") or not out_field.startswith(
"frames."
in_field, _ = sample_collection._handle_frame_field(in_field)
out_field, _ = sample_collection._handle_frame_field(out_field)
_in_field = sample_collection._FRAMES_PREFIX + in_field

fov.validate_video_collection(sample_collection)
fov.validate_collection_label_fields(
sample_collection, _in_field, fo.Detections
)

for sample in sample_collection.iter_samples(
autosave=True, progress=progress
):
raise ValueError(
"in_field and out_field must not be empty and must start with 'frames.'"
)
try:
DeepSort.track_sample(
sample,
in_field,
out_field=out_field,
max_age=max_age,
keep_confidence=keep_confidence,
)
except Exception as e:
if not skip_failures:
raise e

for sample in dataset.iter_samples(autosave=True, progress=progress):
tracker = dsrt.DeepSort(max_age=max_age)
logger.warning("Sample: %s\nError: %s\n", sample.id, e)

cap = cv2.VideoCapture(sample.filepath)
frames_list = []
@staticmethod
def track_sample(
sample,
in_field,
out_field="ds_tracks",
max_age=5,
keep_confidence=False,
):
"""Performs object tracking using the DeepSort algorithm on the given
video sample.

while True:
ret, frame = cap.read()
if not ret:
break
frames_list.append(frame)
DeepSort is an algorithm for tracking multiple objects in video streams
based on deep learning techniques. It associates bounding boxes between
frames and maintains tracks of objects over time.

cap.release()
Args:
sample: a :class:`fiftyone.core.sample.Sample`
in_field: the name of the frame field containing
:class:`fiftyone.core.labels.Detections` to track
out_field ("ds_tracks"): the name of a frame field to store the
output :class:`fiftyone.core.labels.Detections` with tracking
information. Unlike :meth:`track`, the name is used as-is to index
each frame, so it should be given without the ``"frames."`` prefix
max_age (5): the maximum number of consecutive missed detections before
a track is deleted
keep_confidence (False): whether to store the detection confidence
of the tracked objects in the ``out_field``
"""
tracker = dsrt.DeepSort(max_age=max_age)

if len(frames_list) != len(sample.frames):
logger.error(
"Unable to align the captured frames with the encoded frames!"
)
return
with etav.FFmpegVideoReader(sample.filepath) as video_reader:
for img in video_reader:
frame = sample.frames[video_reader.frame_number]
frame_width = img.shape[1]
frame_height = img.shape[0]

for frame_idx, frame in sample.frames.items():
frame_detections = frame[in_field[len("frames.") :]]
bbs = []
extracted_detections = foz.deepcopy(
frame_detections.detections
)
frame_width = frames_list[frame_idx - 1].shape[1]
frame_height = frames_list[frame_idx - 1].shape[0]

for detection in extracted_detections:
coordinates = detection.bounding_box
coordinates[0] *= frame_width
coordinates[1] *= frame_height
coordinates[2] *= frame_width
coordinates[3] *= frame_height
confidence = (
detection.confidence if detection.confidence else 0
)
detection_class = detection.label

bbs.append(((coordinates), confidence, detection_class))

tracks = tracker.update_tracks(
bbs, frame=frames_list[frame_idx - 1]
)
if frame[in_field] is not None:
for detection in frame[in_field].detections:
bbox = detection.bounding_box
coordinates = [
bbox[0] * frame_width,
bbox[1] * frame_height,
bbox[2] * frame_width,
bbox[3] * frame_height,
]
confidence = detection.confidence or 0
label = detection.label
bbs.append(((coordinates), confidence, label))

tracks = tracker.update_tracks(bbs, frame=img)

tracked_detections = []

for _, track in enumerate(tracks):
for track in tracks:
if not track.is_confirmed():
continue

ltrb = track.to_ltrb()
x1, y1, x2, y2 = ltrb
w, h = x2 - x1, y2 - y1
Expand All @@ -113,24 +151,14 @@ def track(
rel_w = min(w / frame_width, 1 - rel_x)
rel_h = min(h / frame_height, 1 - rel_y)

detection = fo.Detection(
label=track.get_det_class(),
bounding_box=[rel_x, rel_y, rel_w, rel_h],
index=track.track_id,
)
if keep_confidence:
tracked_detections.append(
fo.Detection(
label=track.get_det_class(),
confidence=track.get_det_conf(),
bounding_box=[rel_x, rel_y, rel_w, rel_h],
index=track.track_id,
)
)
else:
tracked_detections.append(
fo.Detection(
label=track.get_det_class(),
bounding_box=[rel_x, rel_y, rel_w, rel_h],
index=track.track_id,
)
)

frame[out_field[len("frames.") :]] = fo.Detections(
detections=tracked_detections
)
detection.confidence = track.get_det_conf()

tracked_detections.append(detection)

frame[out_field] = fo.Detections(detections=tracked_detections)
Loading