From 75c9f77eb4eb9b7eb6b8b4d69a8fff2a8fd38d4b Mon Sep 17 00:00:00 2001 From: brimoor Date: Wed, 8 May 2024 13:18:41 -0400 Subject: [PATCH] handle errors, validation, add support for tracking on individual samples --- fiftyone/utils/tracking/deepsort.py | 190 ++++++++++++++++------------ 1 file changed, 109 insertions(+), 81 deletions(-) diff --git a/fiftyone/utils/tracking/deepsort.py b/fiftyone/utils/tracking/deepsort.py index da54e98f3c8..bf5ff5e73ac 100644 --- a/fiftyone/utils/tracking/deepsort.py +++ b/fiftyone/utils/tracking/deepsort.py @@ -5,105 +5,143 @@ | `voxel51.com `_ | """ -# pylint: disable=no-member - import logging -import cv2 + +import eta.core.video as etav import fiftyone as fo -import fiftyone.zoo as foz import fiftyone.core.utils as fou +import fiftyone.core.validation as fov + +dsrt = fou.lazy_import( + "deep_sort_realtime.deepsort_tracker", + callback=lambda: fou.ensure_package("deep-sort-realtime"), +) -dsrt = fou.lazy_import("deep_sort_realtime.deepsort_tracker") logger = logging.getLogger(__name__) -class DeepSort: +class DeepSort(object): @staticmethod def track( - dataset, + sample_collection, in_field, out_field="frames.ds_tracks", max_age=5, keep_confidence=False, + skip_failures=True, progress=None, ): - """Performs object tracking using the DeepSort algorithm on a video dataset. + """Performs object tracking using the DeepSort algorithm on the given + video samples. DeepSort is an algorithm for tracking multiple objects in video streams based on deep learning techniques. It associates bounding boxes between frames and maintains tracks of objects over time. Args: - dataset: a FiftyOne dataset - in_field: the name of the field containing detections in each frame - out_field ("frames.ds_tracks"): the name of the field to store tracking - information of the detections - max_age (5): the maximum number of missed misses before a track - is deleted. + sample_collection: a + :class:`fiftyone.core.collections.SampleCollection` + in_field: the name of a frame field containing + :class:`fiftyone.core.labels.Detections` to track. The + ``"frames."`` prefix is optional + out_field ("frames.ds_tracks"): the name of a frame field to store + the output :class:`fiftyone.core.labels.Detections` with + tracking information. The ``"frames."`` prefix is optional + max_age (5): the maximum number of missed misses before a track is + deleted keep_confidence (False): whether to store the detection confidence - of the tracked objects in the out_field - progress (None): whether to display a progress bar (True/False) + of the tracked objects in the ``out_field`` + skip_failures (True): whether to gracefully continue without + raising an error if tracking fails for a video + progress (False): whether to render a progress bar (True/False), + use the default value ``fiftyone.config.show_progress_bars`` + (None), or a progress callback function to invoke instead """ - if not in_field.startswith("frames.") or not out_field.startswith( - "frames." + in_field, _ = sample_collection._handle_frame_field(in_field) + out_field, _ = sample_collection._handle_frame_field(out_field) + _in_field = sample_collection._FRAMES_PREFIX + in_field + + fov.validate_video_collection(sample_collection) + fov.validate_collection_label_fields( + sample_collection, _in_field, fo.Detections + ) + + for sample in sample_collection.iter_samples( + autosave=True, progress=progress ): - raise ValueError( - "in_field and out_field must not be empty and must start with 'frames.'" - ) + try: + DeepSort.track_sample( + sample, + in_field, + out_field=out_field, + max_age=max_age, + keep_confidence=keep_confidence, + ) + except Exception as e: + if not skip_failures: + raise e - for sample in dataset.iter_samples(autosave=True, progress=progress): - tracker = dsrt.DeepSort(max_age=max_age) + logger.warning("Sample: %s\nError: %s\n", sample.id, e) - cap = cv2.VideoCapture(sample.filepath) - frames_list = [] + @staticmethod + def track_sample( + sample, + in_field, + out_field="ds_tracks", + max_age=5, + keep_confidence=False, + ): + """Performs object tracking using the DeepSort algorithm on the given + video sample. - while True: - ret, frame = cap.read() - if not ret: - break - frames_list.append(frame) + DeepSort is an algorithm for tracking multiple objects in video streams + based on deep learning techniques. It associates bounding boxes between + frames and maintains tracks of objects over time. - cap.release() + Args: + sample: a :class:`fiftyone.core.sample.Sample` + in_field: the name of the frame field containing + :class:`fiftyone.core.labels.Detections` to track + out_field ("ds_tracks"): the name of a frame field to store the + output :class:`fiftyone.core.labels.Detections` with tracking + information. The ``"frames."`` prefix is optional + max_age (5): the maximum number of missed misses before a track is + deleted + keep_confidence (False): whether to store the detection confidence + of the tracked objects in the ``out_field`` + """ + tracker = dsrt.DeepSort(max_age=max_age) - if len(frames_list) != len(sample.frames): - logger.error( - "Unable to align the captured frames with the encoded frames!" - ) - return + with etav.FFmpegVideoReader(sample.filepath) as video_reader: + for img in video_reader: + frame = sample.frames[video_reader.frame_number] + frame_width = img.shape[1] + frame_height = img.shape[0] - for frame_idx, frame in sample.frames.items(): - frame_detections = frame[in_field[len("frames.") :]] bbs = [] - extracted_detections = foz.deepcopy( - frame_detections.detections - ) - frame_width = frames_list[frame_idx - 1].shape[1] - frame_height = frames_list[frame_idx - 1].shape[0] - - for detection in extracted_detections: - coordinates = detection.bounding_box - coordinates[0] *= frame_width - coordinates[1] *= frame_height - coordinates[2] *= frame_width - coordinates[3] *= frame_height - confidence = ( - detection.confidence if detection.confidence else 0 - ) - detection_class = detection.label - - bbs.append(((coordinates), confidence, detection_class)) - tracks = tracker.update_tracks( - bbs, frame=frames_list[frame_idx - 1] - ) + if frame[in_field] is not None: + for detection in frame[in_field].detections: + bbox = detection.bounding_box + coordinates = [ + bbox[0] * frame_width, + bbox[1] * frame_height, + bbox[2] * frame_width, + bbox[3] * frame_height, + ] + confidence = detection.confidence or 0 + label = detection.label + bbs.append(((coordinates), confidence, label)) + + tracks = tracker.update_tracks(bbs, frame=img) tracked_detections = [] - - for _, track in enumerate(tracks): + for track in tracks: if not track.is_confirmed(): continue + ltrb = track.to_ltrb() x1, y1, x2, y2 = ltrb w, h = x2 - x1, y2 - y1 @@ -113,24 +151,14 @@ def track( rel_w = min(w / frame_width, 1 - rel_x) rel_h = min(h / frame_height, 1 - rel_y) + detection = fo.Detection( + label=track.get_det_class(), + bounding_box=[rel_x, rel_y, rel_w, rel_h], + index=track.track_id, + ) if keep_confidence: - tracked_detections.append( - fo.Detection( - label=track.get_det_class(), - confidence=track.get_det_conf(), - bounding_box=[rel_x, rel_y, rel_w, rel_h], - index=track.track_id, - ) - ) - else: - tracked_detections.append( - fo.Detection( - label=track.get_det_class(), - bounding_box=[rel_x, rel_y, rel_w, rel_h], - index=track.track_id, - ) - ) - - frame[out_field[len("frames.") :]] = fo.Detections( - detections=tracked_detections - ) + detection.confidence = track.get_det_conf() + + tracked_detections.append(detection) + + frame[out_field] = fo.Detections(detections=tracked_detections)