Skip to content

Commit

Permalink
refactor: add timeseries helper
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <kotaro.uetake@tier4.jp>
  • Loading branch information
ktro2828 committed Feb 12, 2025
1 parent df9ca19 commit c267727
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 41 deletions.
1 change: 1 addition & 0 deletions t4_devkit/helper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations

from .rendering import * # noqa
from .timeseries import * # noqa
115 changes: 115 additions & 0 deletions t4_devkit/helper/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from t4_devkit.common.timestamp import us2sec

if TYPE_CHECKING:
from t4_devkit import Tier4
from t4_devkit.schema import ObjectAnn, Sample, SampleAnnotation, SampleData


__all__ = ["TimeseriesHelper"]


class TimeseriesHelper:
"""Help `Tier4` class with timeseries relevant operations."""

def __init__(self, t4: Tier4) -> None:
"""Construct a new object.
Args:
t4 (Tier4): `Tier4` instance.
"""
self._t4 = t4

self._sample_and_instance_to_ann3d: dict[tuple[str, str], str] = {
(ann.sample_token, ann.instance_token): ann.token for ann in self._t4.sample_annotation
}
self._sample_data_and_instance_to_ann2d: dict[tuple[str, str], str] = {
(ann.sample_data_token, ann.instance_token): ann.token for ann in self._t4.object_ann
}

def get_sample_annotations_util(
self,
instance_token: str,
sample_token: str,
seconds: float,
) -> list[SampleAnnotation]:
"""Return a list of sample annotations until the specified seconds.
If `seconds>=0` explores future, otherwise past.
Args:
instance_token (str): Instance token of any sample annotations.
sample_token (str): Start sample token.
seconds (float): Time seconds until. If `>=0` explore future, otherwise past.
Returns:
List of sample annotation records of the specified instance.
"""
start_sample: Sample = self._t4.get("sample", sample_token)

outputs: list[SampleAnnotation] = []
is_successor = seconds >= 0
current_sample_token = start_sample.next if is_successor else start_sample.prev
while current_sample_token != "":
current_sample: Sample = self._t4.get("sample", current_sample_token)

if abs(us2sec(current_sample.timestamp - start_sample.timestamp)) > abs(seconds):
break

ann_token = self._sample_and_instance_to_ann3d.get(
(current_sample_token, instance_token)
)
if ann_token is not None:
outputs.append(self._t4.get("sample_annotation", ann_token))

current_sample_token = current_sample.next if is_successor else current_sample.prev

return outputs

def get_object_anns_until(
self,
instance_token: str,
sample_data_token: str,
seconds: float,
) -> list[ObjectAnn]:
"""Return a list of object anns until the specified seconds.
If `seconds>=0` explores future, otherwise past.
Args:
instance_token (str): Instance token of any object anns.
sample_data_token (str): Start sample data token.
seconds (float): Time seconds until. If `>=0` explore future, otherwise past.
Returns:
List of object annotation records of the specified instance.
"""
start_sample_data: SampleData = self._t4.get("sample_data", sample_data_token)

outputs: list[ObjectAnn] = []
is_successor = seconds >= 0
current_sample_data_token = (
start_sample_data.next if is_successor else start_sample_data.prev
)
while current_sample_data_token != "":
current_sample_data: SampleData = self._t4.get("sample_data", current_sample_data_token)

if abs(us2sec(current_sample_data.timestamp - start_sample_data.timestamp)) > abs(
seconds
):
break

ann_token = self._sample_data_and_instance_to_ann2d.get(
(current_sample_data_token, instance_token)
)
if ann_token is not None:
outputs.append(self._t4.get("object_ann", ann_token))

current_sample_data_token = (
current_sample_data.next if is_successor else current_sample_data.prev
)

return outputs
44 changes: 3 additions & 41 deletions t4_devkit/tier4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pyquaternion import Quaternion

from t4_devkit.common.geometry import is_box_in_image, view_points
from t4_devkit.common.timestamp import us2sec
from t4_devkit.dataclass import (
Box2D,
Box3D,
Expand All @@ -21,7 +20,7 @@
Shape,
ShapeType,
)
from t4_devkit.helper import RenderingHelper
from t4_devkit.helper import RenderingHelper, TimeseriesHelper
from t4_devkit.schema import SchemaName, SensorModality, VisibilityLevel, build_schema

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,6 +134,7 @@ def __init__(self, version: str, data_root: str, verbose: bool = True) -> None:
print(f"Done loading in {elapsed_time:.3f} seconds.\n======")

# initialize helpers after finishing construction of Tier4
self._timeseries_helper = TimeseriesHelper(self)
self._rendering_helper = RenderingHelper(self)

def __load_table__(self, schema: SchemaName) -> list[SchemaTable]:
Expand Down Expand Up @@ -219,15 +219,10 @@ def __make_reverse_index__(self, verbose: bool) -> None:
sample_record: Sample = self.get("sample", record.sample_token)
sample_record.data[record.channel] = record.token

self._sample_and_instance_to_ann3d: dict[tuple[str, str], str] = {}
for ann_record in self.sample_annotation:
sample_record: Sample = self.get("sample", ann_record.sample_token)
sample_record.ann_3ds.append(ann_record.token)

self._sample_and_instance_to_ann3d[(sample_record.token, ann_record.instance_token)] = (
ann_record.token
)

for ann_record in self.object_ann:
sd_record: SampleData = self.get("sample_data", ann_record.sample_data_token)
sample_record: Sample = self.get("sample", sd_record.sample_token)
Expand Down Expand Up @@ -451,7 +446,7 @@ def get_box3d(self, sample_annotation_token: str, *, future_seconds: float = 0.0

if future_seconds > 0.0:
# NOTE: Future trajectory is map coordinate frame
anns: list[SampleAnnotation] = self.get_sample_annotations_until(
anns: list[SampleAnnotation] = self._timeseries_helper.get_sample_annotations_util(
ann.instance_token, ann.sample_token, future_seconds
)
if len(anns) == 0:
Expand All @@ -461,39 +456,6 @@ def get_box3d(self, sample_annotation_token: str, *, future_seconds: float = 0.0
else:
return box

def get_sample_annotations_until(
self,
instance_token: str,
sample_token: str,
seconds: float,
) -> list[SampleAnnotation]:
"""Return a list of sample annotations until the specified seconds.
Args:
instance_token (str): Instance token.
sample_token (str): Start sample token.
seconds (float): Time seconds until.
Returns:
list[SampleAnnotation]: List of sample annotation records.
"""
outputs = []
start_sample: Sample = self.get("sample", sample_token)

current_sample = start_sample
while current_sample.next != "":
next_sample: Sample = self.get("sample", current_sample.next)
if us2sec(next_sample.timestamp - start_sample.timestamp) > seconds:
break

ann_token = self._sample_and_instance_to_ann3d.get((next_sample.token, instance_token))
if ann_token is not None:
outputs.append(self.get("sample_annotation", ann_token))

current_sample = next_sample

return outputs

def get_box2d(self, object_ann_token: str) -> Box2D:
"""Return a Box2D class from a `object_ann` record.
Expand Down

0 comments on commit c267727

Please sign in to comment.