Skip to content

Commit

Permalink
feat(lib): preserve objects when samping and num_objects (#140)
Browse files Browse the repository at this point in the history
* feat(lib): preserve objects when samping and `num_objects`

* chore(lint): silence pyright
  • Loading branch information
jeertmans authored Oct 11, 2024
1 parent 7c35789 commit 4b4ba80
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
63 changes: 60 additions & 3 deletions differt/src/differt/geometry/triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __check_init__(self) -> None: # noqa: D105,PLW3201
@jaxtyped(
typechecker=None
) # typing.Self is (currently) not compatible with jaxtyping and beartype
def __getitem__(self, key: slice | Int[ArrayLike, "*batch"]) -> Self:
def __getitem__(self, key: slice | Int[ArrayLike, " n"]) -> Self:
"""Return a copy of this mesh, taking only specific triangles.
Warning:
Expand Down Expand Up @@ -188,6 +188,15 @@ def num_quads(self) -> int:

return self.triangles.shape[0] // 2

@property
def num_objects(self) -> int:
"""The number of objects.
This is a convenient alias to :attr:`num_quads` :attr:`assume_quads` is :data:`True`
else :attr:`num_triangles`.
"""
return self.num_quads if self.assume_quads else self.num_triangles

@property
@jax.jit
@jaxtyped(typechecker=typechecker)
Expand Down Expand Up @@ -439,24 +448,72 @@ def sample(
self,
size: int,
replace: bool = False,
preserve: bool = False,
*,
key: PRNGKeyArray,
) -> Self:
"""
Generate a new mesh by randomly sampling triangles from this geometry.
Warning:
If :attr:`assume_quads` is :data:`True`, then quadrilaterals are
sampled.
Args:
size: The size of the sample, i.e., the number of triangles.
replace: Whether to sample with or without replacement.
preserve: Whether to preserve :attr:`object_bounds`, otherwise
it is discarded.
Object bounds are re-generated by sorting the randomly generated samples,
which takes additional time.
Setting this to :data:`True` has no effect if :attr:`object_bounds`
is :data:`None`.
key: The :func:`jax.random.key` to be used.
Returns:
A new random mesh.
"""
indices = jax.random.choice(
key,
self.num_triangles,
self.num_objects,
shape=(size,),
replace=replace,
)
return self[indices]

if preserve and self.object_bounds is not None:
indices = jnp.sort(indices)
object_bounds = jnp.stack(
(
jnp.searchsorted(indices, self.object_bounds[:, 0]),
jnp.searchsorted(indices, self.object_bounds[:, 1]),
),
axis=-1,
)
else:
object_bounds = None

if self.assume_quads:
indices = jnp.stack((indices, indices + 1), axis=-1).reshape(-1)

return eqx.tree_at(
lambda m: (
m.vertices,
m.triangles,
m.face_colors,
m.face_materials,
m.object_bounds,
),
self,
(
self.vertices,
self.triangles[indices, :],
self.face_colors[indices, :] if self.face_colors is not None else None,
self.face_materials[indices]
if self.face_materials is not None
else None,
object_bounds,
),
is_leaf=lambda x: x is None,
)
27 changes: 27 additions & 0 deletions differt/tests/geometry/test_triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def test_num_quads(self, two_buildings_mesh: TriangleMesh) -> None:
# 'tree_at' bypasses '__check_init__', so this will not raise an error
_ = eqx.tree_at(lambda m: m.assume_quads, non_quad_mesh, replace=True)

def test_num_objects(self, two_buildings_mesh: TriangleMesh) -> None:
assert two_buildings_mesh.num_objects == 24
assert two_buildings_mesh.set_assume_quads().num_objects == 12

def test_get_item(self, two_buildings_mesh: TriangleMesh) -> None:
got = two_buildings_mesh[:]

Expand Down Expand Up @@ -302,3 +306,26 @@ def test_normals(self, two_buildings_mesh: TriangleMesh) -> None:

def test_plot(self, sphere_mesh: TriangleMesh) -> None:
sphere_mesh.plot()

def test_sample(self, two_buildings_mesh: TriangleMesh, key: PRNGKeyArray) -> None:
assert two_buildings_mesh.sample(10, key=key).num_triangles == 10

with pytest.raises(
ValueError, match="Cannot take a larger sample than population"
):
assert two_buildings_mesh.sample(30, key=key)

assert two_buildings_mesh.sample(30, replace=True, key=key).num_triangles == 30

assert two_buildings_mesh.set_assume_quads().sample(5, key=key).num_quads == 5

two_buildings_mesh = eqx.tree_at(
lambda m: m.object_bounds,
two_buildings_mesh,
jnp.array([[0, 12], [12, 24]]),
is_leaf=lambda x: x is None,
)

assert two_buildings_mesh.sample(
13, key=key, preserve=True
).object_bounds.shape == (2, 2) # type: ignore[reportOptionalMemberAccess]

0 comments on commit 4b4ba80

Please sign in to comment.