Skip to content

Commit

Permalink
PLY load normals
Browse files Browse the repository at this point in the history
Summary: Add ability to load normals when they are present in a PLY file.

Reviewed By: nikhilaravi

Differential Revision: D26458971

fbshipit-source-id: 658270b611f7624eab4f5f62ff438038e1d25723
  • Loading branch information
bottler authored and facebook-github-bot committed May 4, 2021
1 parent b314bee commit 6fa66f5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 21 deletions.
75 changes: 55 additions & 20 deletions pytorch3d/io/ply_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:

def _get_verts_column_indices(
vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]], float]:
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
"""
Get the columns of verts and verts_colors in the vertex
Get the columns of verts, verts_colors, and verts_normals in the vertex
element of a parsed ply file, together with a color scale factor.
When the colors are in byte format, they are scaled from 0..255 to [0,1].
Otherwise they are not scaled.
Expand All @@ -793,11 +793,14 @@ def _get_verts_column_indices(
property double x
property double y
property double z
property double nx
property double ny
property double nz
property uchar red
property uchar green
property uchar blue
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
Args:
vertex_head: as returned from load_ply_raw.
Expand All @@ -807,9 +810,12 @@ def _get_verts_column_indices(
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
"""
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
normal_idxs: List[Optional[int]] = [None, None, None]
for i, prop in enumerate(vertex_head.properties):
if prop.list_size_type is not None:
raise ValueError("Invalid vertices in file: did not expect list.")
Expand All @@ -819,6 +825,9 @@ def _get_verts_column_indices(
for j, name in enumerate(["red", "green", "blue"]):
if prop.name == name:
color_idxs[j] = i
for j, name in enumerate(["nx", "ny", "nz"]):
if prop.name == name:
normal_idxs[j] = i
if None in point_idxs:
raise ValueError("Invalid vertices in file.")
color_scale = 1.0
Expand All @@ -831,21 +840,23 @@ def _get_verts_column_indices(
point_idxs,
None if None in color_idxs else cast(List[int], color_idxs),
color_scale,
None if None in normal_idxs else cast(List[int], normal_idxs),
)


def _get_verts(
header: _PlyHeader, elements: dict
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Get the vertex locations and colors from a parsed ply file.
Get the vertex locations, colors and normals from a parsed ply file.
Args:
header, elements: as returned from load_ply_raw.
Returns:
verts: FloatTensor of shape (V, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
"""

vertex = elements.get("vertex", None)
Expand All @@ -854,14 +865,16 @@ def _get_verts(
if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
vertex_head
)

# Case of no vertices
if vertex_head.count == 0:
verts = torch.zeros((0, 3), dtype=torch.float32)
if color_idxs is None:
return verts, None
return verts, torch.zeros((0, 3), dtype=torch.float32)
return verts, None, None
return verts, torch.zeros((0, 3), dtype=torch.float32), None

# Simple case where the only data is the vertices themselves
if (
Expand All @@ -870,9 +883,10 @@ def _get_verts(
and vertex[0].ndim == 2
and vertex[0].shape[1] == 3
):
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None

vertex_colors = None
vertex_normals = None

if len(vertex) == 1:
# This is the case where the whole vertex element has one type,
Expand All @@ -882,6 +896,10 @@ def _get_verts(
vertex_colors = color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32
)
if normal_idxs is not None:
vertex_normals = torch.tensor(
vertex[0][:, normal_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type.
Expand Down Expand Up @@ -913,13 +931,22 @@ def _get_verts(
partnum, col = prop_to_partnum_col[color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale
if normal_idxs is not None:
vertex_normals = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for axis in range(3):
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]

return verts, vertex_colors
return verts, vertex_colors, vertex_normals


def _load_ply(
f, *, path_manager: PathManager
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]:
"""
Load the data from a .ply file.
Expand All @@ -935,10 +962,11 @@ def _load_ply(
verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
"""
header, elements = _load_ply_raw(f, path_manager=path_manager)

verts, vertex_colors = _get_verts(header, elements)
verts, vertex_colors, vertex_normals = _get_verts(header, elements)

face = elements.get("face", None)
if face is not None:
Expand Down Expand Up @@ -976,7 +1004,7 @@ def _load_ply(
if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0])

return verts, faces, vertex_colors
return verts, faces, vertex_colors, vertex_normals


def load_ply(
Expand Down Expand Up @@ -1031,7 +1059,7 @@ def load_ply(

if path_manager is None:
path_manager = PathManager()
verts, faces, _ = _load_ply(f, path_manager=path_manager)
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)

Expand Down Expand Up @@ -1211,18 +1239,23 @@ def read(
if not endswith(path, self.known_suffixes):
return None

verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
verts, faces, verts_colors, verts_normals = _load_ply(
f=path, path_manager=path_manager
)
if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)

textures = None
texture = None
if include_textures and verts_colors is not None:
textures = TexturesVertex([verts_colors.to(device)])
texture = TexturesVertex([verts_colors.to(device)])

if verts_normals is not None:
verts_normals = [verts_normals]
mesh = Meshes(
verts=[verts.to(device)],
faces=[faces.to(device)],
textures=textures,
textures=texture,
verts_normals=verts_normals,
)
return mesh

Expand Down Expand Up @@ -1286,12 +1319,14 @@ def read(
if not endswith(path, self.known_suffixes):
return None

verts, faces, features = _load_ply(f=path, path_manager=path_manager)
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
verts = verts.to(device)
if features is not None:
features = [features.to(device)]
if normals is not None:
normals = [normals.to(device)]

pointcloud = Pointclouds(points=[verts], features=features)
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
return pointcloud

def save(
Expand Down
51 changes: 50 additions & 1 deletion tests/test_io_ply.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,18 @@ def test_save_load_meshes(self):
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
)
vert_colors = torch.rand_like(verts)
texture = TexturesVertex(verts_features=[vert_colors])

for do_textures in itertools.product([True, False]):
for do_textures, do_normals in itertools.product([True, False], [True, False]):
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture if do_textures else None,
verts_normals=[normals] if do_normals else None,
)
device = torch.device("cuda:0")

Expand All @@ -236,12 +240,57 @@ def test_save_load_meshes(self):
mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
if do_normals:
self.assertTrue(mesh.has_verts_normals())
self.assertTrue(mesh2.has_verts_normals())
self.assertClose(
mesh2.verts_normals_padded(), mesh.verts_normals_padded()
)
else:
self.assertFalse(mesh.has_verts_normals())
self.assertFalse(mesh2.has_verts_normals())
self.assertFalse(torch.allclose(mesh2.verts_normals_padded(), normals))
if do_textures:
self.assertIsInstance(mesh2.textures, TexturesVertex)
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
else:
self.assertIsNone(mesh2.textures)

def test_save_load_with_normals(self):
points = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
)
features = torch.rand_like(points)

for do_features, do_normals in itertools.product([True, False], [True, False]):
cloud = Pointclouds(
points=[points],
features=[features] if do_features else None,
normals=[normals] if do_normals else None,
)
device = torch.device("cuda:0")

io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
io.save_pointcloud(cloud.cuda(), f.name)
f.flush()
cloud2 = io.load_pointcloud(f.name, device=device)
self.assertEqual(cloud2.device, device)
cloud2 = cloud2.cpu()
self.assertClose(cloud2.points_padded(), cloud.points_padded())
if do_normals:
self.assertClose(cloud2.normals_padded(), cloud.normals_padded())
else:
self.assertIsNone(cloud.normals_padded())
self.assertIsNone(cloud2.normals_padded())
if do_features:
self.assertClose(cloud2.features_packed(), features)
else:
self.assertIsNone(cloud2.features_packed())

def test_save_ply_invalid_shapes(self):
# Invalid vertices shape
with self.assertRaises(ValueError) as error:
Expand Down

0 comments on commit 6fa66f5

Please sign in to comment.