Add Autoware compatibility #1

Merged · 8 commits · Jun 13, 2024
16 changes: 4 additions & 12 deletions docker/Dockerfile
@@ -1,5 +1,5 @@
-ARG PYTORCH="1.9.0"
-ARG CUDA="11.1"
+ARG PYTORCH="1.13.1"
+ARG CUDA="11.6"
 ARG CUDNN="8"

 FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
@@ -9,14 +9,6 @@ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
     CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
     FORCE_CUDA="1"

-# Avoid Public GPG key error
-# https://github.com/NVIDIA/nvidia-docker/issues/1631
-RUN rm /etc/apt/sources.list.d/cuda.list \
-    && rm /etc/apt/sources.list.d/nvidia-ml.list \
-    && apt-key del 7fa2af80 \
-    && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
-    && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
-
 # (Optional, use Mirror to speed up downloads)
 # RUN sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list && \
 #    pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
@@ -29,11 +21,11 @@ RUN apt-get update \

 # Install MMEngine, MMCV and MMDetection
 RUN pip install openmim && \
-    mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0"
+    mim install "mmengine" "mmcv>=2.0.0rc4" "mmdet>=3.0.0rc5, <3.3.0"

 # Install MMDetection3D
 RUN conda clean --all \
-    && git clone https://github.com/open-mmlab/mmdetection3d.git -b dev-1.x /mmdetection3d \
+    && git clone https://github.com/autowarefoundation/mmdetection3d.git -b main /mmdetection3d \
     && cd /mmdetection3d \
     && pip install --no-cache-dir -e .
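Note: with these changes the image is built against the autowarefoundation fork of mmdetection3d. A typical build command, run from the repository root, would be `docker build -t mmdetection3d-autoware -f docker/Dockerfile .` (the image tag here is just an example, not one mandated by this PR).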
11 changes: 11 additions & 0 deletions projects/AutowareCenterPoint/README.md
@@ -0,0 +1,11 @@
## Introduction

The **[mmdetection3d](https://github.com/open-mmlab/mmdetection3d)** repository adds an extra voxel encoder
feature, the voxel center z offset, to the CenterPoint 3D object detection model. This feature is not used in the
**[original implementation](https://github.com/tianweiy/CenterPoint)**, and
Autoware keeps its input size consistent with that original implementation. Consequently,
to integrate with Autoware's lidar_centerpoint package, we have forked the upstream repository and made
the necessary code modifications.

To train custom CenterPoint models and convert them into ONNX format for deployment in Autoware, follow the instructions in the README.md of
Autoware's **[lidar_centerpoint](https://autowarefoundation.github.io/autoware.universe/main/perception/lidar_centerpoint/)** package, which provides a step-by-step guide to training the CenterPoint model.
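For orientation, here is a hedged sketch of how a training config might select this encoder. The import path follows this PR's project layout, but every numeric value below is an illustrative placeholder, not a setting taken from this repository:

```python
# Hypothetical config fragment (values are illustrative, not from this PR).
custom_imports = dict(
    imports=['projects.AutowareCenterPoint.centerpoint'],
    allow_failed_imports=False)

model = dict(
    pts_voxel_encoder=dict(
        type='PillarFeatureNetAutoware',
        in_channels=4,
        feat_channels=(64, ),
        with_cluster_center=True,
        with_voxel_center=True,
        # Autoware's lidar_centerpoint expects no voxel-center-z input.
        use_voxel_center_z=False,
        voxel_size=(0.32, 0.32, 8.0),
        point_cloud_range=(-76.8, -76.8, -4.0, 76.8, 76.8, 4.0)))
```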
3 changes: 3 additions & 0 deletions projects/AutowareCenterPoint/centerpoint/__init__.py
@@ -0,0 +1,3 @@
from .pillar_encoder_autoware import PillarFeatureNetAutoware

__all__ = ['PillarFeatureNetAutoware']
167 changes: 167 additions & 0 deletions projects/AutowareCenterPoint/centerpoint/pillar_encoder_autoware.py
@@ -0,0 +1,167 @@
from typing import Optional, Tuple

import torch
from torch import Tensor, nn

from mmdet3d.models.voxel_encoders.utils import (PFNLayer,
                                                 get_paddings_indicator)
from mmdet3d.registry import MODELS


@MODELS.register_module()
class PillarFeatureNetAutoware(nn.Module):
"""Pillar Feature Net.

The network prepares the pillar features and performs forward pass
through PFNLayers.

Args:
in_channels (int, optional): Number of input features,
either x, y, z or x, y, z, r. Defaults to 4.
feat_channels (tuple, optional): Number of features in each of the
N PFNLayers. Defaults to (64, ).
with_distance (bool, optional): Whether to include Euclidean distance
to points. Defaults to False.
with_cluster_center (bool, optional): [description]. Defaults to True.
with_voxel_center (bool, optional): [description]. Defaults to True.
voxel_size (tuple[float], optional): Size of voxels, only utilize x
and y size. Defaults to (0.2, 0.2, 4).
point_cloud_range (tuple[float], optional): Point cloud range, only
utilizes x and y min. Defaults to (0, -40, -3, 70.4, 40, 1).
norm_cfg ([type], optional): [description].
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
legacy (bool, optional): Whether to use the new behavior or
the original behavior. Defaults to True.
"""

    def __init__(
            self,
            in_channels: Optional[int] = 4,
            feat_channels: Optional[tuple] = (64, ),
            with_distance: Optional[bool] = False,
            with_cluster_center: Optional[bool] = True,
            with_voxel_center: Optional[bool] = True,
            voxel_size: Optional[Tuple[float]] = (0.2, 0.2, 4),
            point_cloud_range: Optional[Tuple[float]] = (0, -40, -3, 70.4, 40, 1),
            norm_cfg: Optional[dict] = dict(type='BN1d', eps=1e-3, momentum=0.01),
            mode: Optional[str] = 'max',
            legacy: Optional[bool] = True,
            use_voxel_center_z: Optional[bool] = True,
    ):
        super(PillarFeatureNetAutoware, self).__init__()
        assert len(feat_channels) > 0
        self.legacy = legacy
        self.use_voxel_center_z = use_voxel_center_z
        if with_cluster_center:
            in_channels += 3
        if with_voxel_center:
            in_channels += 2
            if self.use_voxel_center_z:
                in_channels += 1
        if with_distance:
            in_channels += 1
        self._with_distance = with_distance
        self._with_cluster_center = with_cluster_center
        self._with_voxel_center = with_voxel_center
        # Create PillarFeatureNet layers
        self.in_channels = in_channels
        feat_channels = [in_channels] + list(feat_channels)
        pfn_layers = []
        for i in range(len(feat_channels) - 1):
            in_filters = feat_channels[i]
            out_filters = feat_channels[i + 1]
            if i < len(feat_channels) - 2:
                last_layer = False
            else:
                last_layer = True
            pfn_layers.append(
                PFNLayer(
                    in_filters,
                    out_filters,
                    norm_cfg=norm_cfg,
                    last_layer=last_layer,
                    mode=mode))
        self.pfn_layers = nn.ModuleList(pfn_layers)

        # Need pillar (voxel) size and x/y offset in order to calculate offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.vz = voxel_size[2]
        self.x_offset = self.vx / 2 + point_cloud_range[0]
        self.y_offset = self.vy / 2 + point_cloud_range[1]
        self.z_offset = self.vz / 2 + point_cloud_range[2]
        self.point_cloud_range = point_cloud_range

**Review thread on `forward`:**

**@Shin-kyoto:** @kaancolak This function is very long, and it is really difficult for a reviewer to check. Could you cover this function with a test?

**@kaancolak (Collaborator, Author):** Actually, this function is copied from the original repository; it only contains a small adjustment to make it compatible with Autoware's CenterPoint TensorRT implementation. Autoware's implementation does not take the voxel center z as an encoder input, but the original mmdetection3d implementation does.

**@Shin-kyoto:** Thank you for the quick response! This repository is open source and will receive PRs from various contributors. To ease the burden on reviewers, adding tests seems to be the simplest solution for the future. What do you think? It is also fine to include tests that are already present in the original code.

**@kaancolak (Collaborator, Author):** @Shin-kyoto-san, thank you. IMO, we can add tests for customized operations such as the "autoware_voxel_encoder", the "T4Dataset" operations, etc.; the original parts of the implementation have already been checked by the tests in the main mmdet repository.

**@Shin-kyoto:** I also think that we can add tests for the customized operations. PillarFeatureNetAutoware is also a customized class, and a test will be really helpful when we review the PR.

**@kaancolak (Collaborator, Author):** Hi @Shin-kyoto-san, I added tests for the T4 dataset and the autoware_voxel_encoder, as well as for PillarFeatureNetAutoware, the base class of the autoware_voxel_encoder. You can run them with:

cd mmdetection3d/ && pytest -s projects/AutowareCenterPoint/tests/

Relevant files: https://github.com/autowarefoundation/mmdetection3d/pull/1/files#diff-0b5c1ac645064857b8ac247c5d193291b74c6130bfdfa4c22aed90d2bd0df58a

    def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
                *args, **kwargs) -> Tensor:
"""Forward function.

Args:
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel.

Returns:
torch.Tensor: Features of pillars.
"""
features_ls = [features]
# Find distance of x, y, and z from cluster center
if self._with_cluster_center:
points_mean = features[:, :, :3].sum(
dim=1, keepdim=True) / num_points.type_as(features).view(
-1, 1, 1)
f_cluster = features[:, :, :3] - points_mean
features_ls.append(f_cluster)

# Find distance of x, y, and z from pillar center
dtype = features.dtype
if self._with_voxel_center:
center_feature_size = 3 if self.use_voxel_center_z else 2
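            # `coors` columns are (batch_idx, z, y, x), which is why x pairs
            # with coors[:, 3], y with coors[:, 2], and z with coors[:, 1].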
            if not self.legacy:
                f_center = torch.zeros_like(
                    features[:, :, :center_feature_size])
                f_center[:, :, 0] = features[:, :, 0] - (
                    coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
                    self.x_offset)
                f_center[:, :, 1] = features[:, :, 1] - (
                    coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
                    self.y_offset)
                if self.use_voxel_center_z:
                    f_center[:, :, 2] = features[:, :, 2] - (
                        coors[:, 1].to(dtype).unsqueeze(1) * self.vz +
                        self.z_offset)
            else:
                f_center = features[:, :, :center_feature_size]
                f_center[:, :, 0] = f_center[:, :, 0] - (
                    coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
                    self.x_offset)
                f_center[:, :, 1] = f_center[:, :, 1] - (
                    coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
                    self.y_offset)
                if self.use_voxel_center_z:
                    f_center[:, :, 2] = f_center[:, :, 2] - (
                        coors[:, 1].type_as(features).unsqueeze(1) * self.vz +
                        self.z_offset)
            features_ls.append(f_center)

        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features_ls.append(points_dist)

        # Combine together feature decorations
        features = torch.cat(features_ls, dim=-1)
        # The feature decorations were calculated without regard to whether
        # pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_points, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        features *= mask

        for pfn in self.pfn_layers:
            features = pfn(features, num_points)

        return features.squeeze(1)
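To illustrate the kind of test discussed in the review thread above, here is a minimal sketch of a shape check for PillarFeatureNetAutoware. It assumes the repository root is on PYTHONPATH; the tensor shapes and parameter values are invented for illustration, and the actual tests live under projects/AutowareCenterPoint/tests/:

```python
import torch

from projects.AutowareCenterPoint.centerpoint import PillarFeatureNetAutoware


def test_pillar_feature_net_autoware_output_shape():
    # 10 pillars, up to 20 points each, 4 input features (x, y, z, r).
    # All values below are illustrative, not taken from this PR.
    encoder = PillarFeatureNetAutoware(
        in_channels=4,
        feat_channels=(64, ),
        use_voxel_center_z=False,  # match Autoware's TensorRT input layout
        voxel_size=(0.2, 0.2, 8.0),
        point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0))
    features = torch.rand(10, 20, 4)
    num_points = torch.randint(1, 20, (10, ))
    # coors layout: (batch_idx, z, y, x)
    coors = torch.randint(0, 100, (10, 4))
    out = encoder(features, num_points, coors)
    # One 64-dim feature vector per pillar.
    assert out.shape == (10, 64)
```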