diff --git a/fiftyone/utils/tracking/deepsort.py b/fiftyone/utils/tracking/deepsort.py index da54e98f3c8..0c688cf0f60 100644 --- a/fiftyone/utils/tracking/deepsort.py +++ b/fiftyone/utils/tracking/deepsort.py @@ -5,105 +5,143 @@ | `voxel51.com `_ | """ -# pylint: disable=no-member - import logging -import cv2 + +import eta.core.video as etav import fiftyone as fo -import fiftyone.zoo as foz import fiftyone.core.utils as fou +import fiftyone.core.validation as fov + +dsrt = fou.lazy_import( + "deep_sort_realtime.deepsort_tracker", + callback=lambda: fou.ensure_package("deep-sort-realtime"), +) -dsrt = fou.lazy_import("deep_sort_realtime.deepsort_tracker") logger = logging.getLogger(__name__) -class DeepSort: +class DeepSort(object): @staticmethod def track( - dataset, + sample_collection, in_field, out_field="frames.ds_tracks", max_age=5, keep_confidence=False, + skip_failures=True, progress=None, ): - """Performs object tracking using the DeepSort algorithm on a video dataset. + """Performs object tracking using the DeepSort algorithm on the given + video samples. DeepSort is an algorithm for tracking multiple objects in video streams based on deep learning techniques. It associates bounding boxes between frames and maintains tracks of objects over time. Args: - dataset: a FiftyOne dataset - in_field: the name of the field containing detections in each frame - out_field ("frames.ds_tracks"): the name of the field to store tracking - information of the detections - max_age (5): the maximum number of missed misses before a track - is deleted. + sample_collection: a + :class:`fiftyone.core.collections.SampleCollection` + in_field: the name of a frame field containing + :class:`fiftyone.core.labels.Detections` to track. The + ``"frames."`` prefix is optional + out_field ("frames.ds_tracks"): the name of a frame field to store + the output :class:`fiftyone.core.labels.Detections` with + tracking information. The ``"frames."`` prefix is optional + max_age (5): the maximum number of missed misses before a track is + deleted keep_confidence (False): whether to store the detection confidence - of the tracked objects in the out_field - progress (None): whether to display a progress bar (True/False) + of the tracked objects in the ``out_field`` + skip_failures (True): whether to gracefully continue without + raising an error if tracking fails for a video + progress (False): whether to render a progress bar (True/False), + use the default value ``fiftyone.config.show_progress_bars`` + (None), or a progress callback function to invoke instead """ - if not in_field.startswith("frames.") or not out_field.startswith( - "frames." - ): - raise ValueError( - "in_field and out_field must not be empty and must start with 'frames.'" - ) - - for sample in dataset.iter_samples(autosave=True, progress=progress): - tracker = dsrt.DeepSort(max_age=max_age) - - cap = cv2.VideoCapture(sample.filepath) - frames_list = [] - - while True: - ret, frame = cap.read() - if not ret: - break - frames_list.append(frame) - - cap.release() - - if len(frames_list) != len(sample.frames): - logger.error( - "Unable to align the captured frames with the encoded frames!" + in_field, _ = sample_collection._handle_frame_field(in_field) + out_field, _ = sample_collection._handle_frame_field(out_field) + _in_field = sample_collection._FRAMES_PREFIX + in_field + + fov.validate_video_collection(sample_collection) + fov.validate_collection_label_fields( + sample_collection, _in_field, fo.Detections + ) + + view = sample_collection.select_fields(_in_field) + + for sample in view.iter_samples(autosave=True, progress=progress): + try: + DeepSort.track_sample( + sample, + in_field, + out_field=out_field, + max_age=max_age, + keep_confidence=keep_confidence, ) - return + except Exception as e: + if not skip_failures: + raise e - for frame_idx, frame in sample.frames.items(): - frame_detections = frame[in_field[len("frames.") :]] - bbs = [] - extracted_detections = foz.deepcopy( - frame_detections.detections - ) - frame_width = frames_list[frame_idx - 1].shape[1] - frame_height = frames_list[frame_idx - 1].shape[0] - - for detection in extracted_detections: - coordinates = detection.bounding_box - coordinates[0] *= frame_width - coordinates[1] *= frame_height - coordinates[2] *= frame_width - coordinates[3] *= frame_height - confidence = ( - detection.confidence if detection.confidence else 0 - ) - detection_class = detection.label + logger.warning("Sample: %s\nError: %s\n", sample.id, e) - bbs.append(((coordinates), confidence, detection_class)) + @staticmethod + def track_sample( + sample, + in_field, + out_field="ds_tracks", + max_age=5, + keep_confidence=False, + ): + """Performs object tracking using the DeepSort algorithm on the given + video sample. - tracks = tracker.update_tracks( - bbs, frame=frames_list[frame_idx - 1] - ) + DeepSort is an algorithm for tracking multiple objects in video streams + based on deep learning techniques. It associates bounding boxes between + frames and maintains tracks of objects over time. - tracked_detections = [] + Args: + sample: a :class:`fiftyone.core.sample.Sample` + in_field: the name of the frame field containing + :class:`fiftyone.core.labels.Detections` to track + out_field ("ds_tracks"): the name of a frame field to store the + output :class:`fiftyone.core.labels.Detections` with tracking + information. The ``"frames."`` prefix is optional + max_age (5): the maximum number of missed misses before a track is + deleted + keep_confidence (False): whether to store the detection confidence + of the tracked objects in the ``out_field`` + """ + tracker = dsrt.DeepSort(max_age=max_age) + + with etav.FFmpegVideoReader(sample.filepath) as video_reader: + for img in video_reader: + frame = sample.frames[video_reader.frame_number] + frame_width = img.shape[1] + frame_height = img.shape[0] + + bbs = [] - for _, track in enumerate(tracks): + if frame[in_field] is not None: + for detection in frame[in_field].detections: + bbox = detection.bounding_box + coordinates = [ + bbox[0] * frame_width, + bbox[1] * frame_height, + bbox[2] * frame_width, + bbox[3] * frame_height, + ] + confidence = detection.confidence or 0 + label = detection.label + bbs.append(((coordinates), confidence, label)) + + tracks = tracker.update_tracks(bbs, frame=img) + + tracked_detections = [] + for track in tracks: if not track.is_confirmed(): continue + ltrb = track.to_ltrb() x1, y1, x2, y2 = ltrb w, h = x2 - x1, y2 - y1 @@ -113,24 +151,14 @@ def track( rel_w = min(w / frame_width, 1 - rel_x) rel_h = min(h / frame_height, 1 - rel_y) + detection = fo.Detection( + label=track.get_det_class(), + bounding_box=[rel_x, rel_y, rel_w, rel_h], + index=track.track_id, + ) if keep_confidence: - tracked_detections.append( - fo.Detection( - label=track.get_det_class(), - confidence=track.get_det_conf(), - bounding_box=[rel_x, rel_y, rel_w, rel_h], - index=track.track_id, - ) - ) - else: - tracked_detections.append( - fo.Detection( - label=track.get_det_class(), - bounding_box=[rel_x, rel_y, rel_w, rel_h], - index=track.track_id, - ) - ) - - frame[out_field[len("frames.") :]] = fo.Detections( - detections=tracked_detections - ) + detection.confidence = track.get_det_conf() + + tracked_detections.append(detection) + + frame[out_field] = fo.Detections(detections=tracked_detections)