Skip to content

Commit

Permalink
chore(docs): update tutorial to match paper content (#147)
Browse files Browse the repository at this point in the history
* chore(docs): update tutorial to match paper content

* fix(docs): typos

* chore(ci): fmt
  • Loading branch information
jeertmans authored Oct 18, 2024
1 parent 3e44c6f commit 6306ade
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 259 deletions.
40 changes: 20 additions & 20 deletions differt/src/differt/geometry/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@jax.jit
@jaxtyped(typechecker=typechecker)
def _cluster_ids(array: Shaped[Array, "batch n"]) -> Int[Array, " batch"]:
def _cell_ids(array: Shaped[Array, "batch n"]) -> Int[Array, " batch"]:
@jaxtyped(typechecker=typechecker)
def scan_fun(
indices: Int[Array, " batch"],
Expand All @@ -40,15 +40,15 @@ def scan_fun(

@jax.jit
@jaxtyped(typechecker=typechecker)
def merge_cluster_ids(
cluster_ids_a: Int[Array, " *batch"],
cluster_ids_b: Int[Array, " *batch"],
def merge_cell_ids(
cell_ids_a: Int[Array, " *batch"],
cell_ids_b: Int[Array, " *batch"],
) -> Int[Array, " *batch"]:
"""
Merge two arrays of cluster indices as returned by :meth:`Paths.multipath_clusters`.
Merge two arrays of cell indices as returned by :meth:`Paths.multipath_cells`.
Let the returned array be ``cluster_ids``,
then ``cluster_ids[i] == cluster_ids[j]`` for all ``i``,
Let the returned array be ``cell_ids``,
then ``cell_ids[i] == cell_ids[j]`` for all ``i``,
``j`` indices if
``(groups_a[i], groups_b[i]) == (groups_a[j], groups_b[j])``,
granted that arrays have been reshaped to uni-dimensional
Expand All @@ -60,15 +60,15 @@ def merge_cluster_ids(
do with the ones used in individual arrays.
Args:
cluster_ids_a: The first array of cluster indices.
cluster_ids_b: The second array of cluster indices.
cell_ids_a: The first array of cell indices.
cell_ids_b: The second array of cell indices.
Returns:
The new array group indices.
"""
batch = cluster_ids_a.shape
return _cluster_ids(
jnp.stack((cluster_ids_a, cluster_ids_b), axis=-1).reshape(-1, 2),
batch = cell_ids_a.shape
return _cell_ids(
jnp.stack((cell_ids_a, cell_ids_b), axis=-1).reshape(-1, 2),
).reshape(batch)


Expand Down Expand Up @@ -175,15 +175,15 @@ def masked_objects(

@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def multipath_clusters(
def multipath_cells(
self,
axis: int = -1,
) -> Int[Array, " *partial_batch"]:
"""
Return an array of same multipath cluster indices.
Return an array of same multipath cell indices.
Let the returned array be ``cluster_ids``,
then ``cluster_ids[i] == cluster_ids[j]`` for all ``i``,
Let the returned array be ``cell_ids``,
then ``cell_ids[i] == cell_ids[j]`` for all ``i``,
``j`` indices if ``self.mask[i, :] == self.mask[j, :]``,
granted that each array has been reshaped to a two-dimensional
array and that ``axis`` is the last dimension. Of course, this
Expand Down Expand Up @@ -218,13 +218,13 @@ def multipath_clusters(
ValueError: If :attr:`mask` is None.
"""
if self.mask is None:
msg = "Cannot create multiplath clusters from non-existing mask!"
msg = "Cannot create multiplath cells from non-existing mask!"
raise ValueError(msg)

mask = jnp.moveaxis(self.mask, axis, -1)
*partial_batch, last_axis = mask.shape

return _cluster_ids(mask.reshape(-1, last_axis)).reshape(partial_batch)
return _cell_ids(mask.reshape(-1, last_axis)).reshape(partial_batch)

@jax.jit
@jaxtyped(typechecker=typechecker)
Expand All @@ -236,7 +236,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]:
undergo the same types of interactions.
Internally, it uses the same logic as
:meth:`multipath_clusters`, but applied to object indices
:meth:`multipath_cells`, but applied to object indices
rather than on mask.
Returns:
Expand Down Expand Up @@ -277,7 +277,7 @@ def group_by_objects(self) -> Int[Array, " *batch"]:
*batch, path_length = self.objects.shape

objects = self.objects.reshape((-1, path_length))
return _cluster_ids(objects).reshape(batch)
return _cell_ids(objects).reshape(batch)

def __iter__(self) -> Iterator[Self]:
"""Return an iterator over masked paths.
Expand Down
28 changes: 14 additions & 14 deletions differt/src/differt/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,44 +439,44 @@ def assemble_paths(

@jax.jit
@jaxtyped(typechecker=typechecker)
def min_distance_between_clusters(
cluster_vertices: Float[Array, "*batch 3"],
cluster_ids: Int[Array, "*batch"],
def min_distance_between_cells(
cell_vertices: Float[Array, "*batch 3"],
cell_ids: Int[Array, "*batch"],
) -> Float[Array, "*batch"]:
"""
Compute the minimal (Euclidean) distance between vertices in different clusters.
Compute the minimal (Euclidean) distance between vertices in different cells.
For every vertex, the minimum distance to another vertex that is not is the same
cluster is computed.
cell is computed.
Args:
cluster_vertices: The array of vertex coordinates.
cluster_ids: The array of corresponding cluster indices.
cell_vertices: The array of vertex coordinates.
cell_ids: The array of corresponding cell indices.
Returns:
The array of minimal distances.
"""

@jaxtyped(typechecker=typechecker)
def scan_fun(
_: None, vertex_and_cluster_id: tuple[Float[Array, "3"], Int[Array, " "]]
_: None, vertex_and_cell_id: tuple[Float[Array, "3"], Int[Array, " "]]
) -> tuple[None, Float[Array, " "]]:
vertex, cluster_id = vertex_and_cluster_id
vertex, cell_id = vertex_and_cell_id
min_dist = jnp.min(
jnp.linalg.norm(
cluster_vertices - vertex,
cell_vertices - vertex,
axis=-1,
),
initial=jnp.inf,
where=(cluster_id != cluster_ids),
where=(cell_id != cell_ids),
)
return None, min_dist

return jax.lax.scan(
scan_fun,
init=None,
xs=(
cluster_vertices.reshape(-1, 3),
cluster_ids.reshape(-1),
cell_vertices.reshape(-1, 3),
cell_ids.reshape(-1),
),
)[1].reshape(cluster_ids.shape)
)[1].reshape(cell_ids.shape)
Loading

0 comments on commit 6306ade

Please sign in to comment.