Skip to content

Commit

Permalink
fix(docs): hacky ref to jaxtyping and equinox modules (#145)
Browse files Browse the repository at this point in the history
* fix(docs): hacky ref to jaxtyping and equinox modules

* fmt
  • Loading branch information
jeertmans authored Oct 14, 2024
1 parent 0b1824d commit ef9013e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
10 changes: 5 additions & 5 deletions differt-core/src/geometry/triangle_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,21 @@ impl TriangleMesh {

#[pymethods]
impl TriangleMesh {
/// ``Float[np.ndarray, 'num_vertices 3']``: The array of triangle vertices.
/// :class:`Float[np.ndarray, 'num_vertices 3']<jaxtyping.Float>`: The array of triangle vertices.
#[getter]
fn vertices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f32>> {
let array = arr2(&self.vertices);
PyArray2::from_owned_array_bound(py, array)
}

/// ``Int[np.ndarray, 'num_triangles 3']``: The array of triangle indices.
/// :class:`Int[np.ndarray, 'num_triangles 3']<jaxtyping.Int>`: The array of triangle indices.
#[getter]
fn triangles<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<usize>> {
let array = arr2(&self.triangles);
PyArray2::from_owned_array_bound(py, array)
}

/// ``Float[np.ndarray, 'num_vertices 3']`` | :data:`None`: The array of face colors.
/// :class:`Float[np.ndarray, 'num_vertices 3']<jaxtyping.Float>` | :data:`None`: The array of face colors.
///
/// The array contains the face colors, as RGB triplets,
/// with a black color used as defaults (if some faces have a color).
Expand All @@ -147,7 +147,7 @@ impl TriangleMesh {
None
}

/// ``Int[np.ndarray, 'num_vertices']`` | :data:`None`: The array of face materials.
/// :class:`Int[np.ndarray, 'num_vertices']<jaxtyping.Int>` | :data:`None`: The array of face materials.
///
/// The array contains the material indices,
/// with a special placeholder value of ``-1``.
Expand All @@ -161,7 +161,7 @@ impl TriangleMesh {
None
}

/// ``Int[np.ndarray, 'num_objects 2']`` | :data:`None`: The array of object indices.
/// :class:`Int[np.ndarray, 'num_objects 2']<jaxtyping.Int>` | :data:`None`: The array of object indices.
///
/// If the present mesh contains multiple objects, usually as a result of
/// appending multiple meshes together, this array contain start end end
Expand Down
55 changes: 42 additions & 13 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datetime import date
from typing import Any

from docutils.nodes import Element, TextElement
from docutils import nodes
from sphinx.addnodes import pending_xref
from sphinx.application import Sphinx
from sphinx.environment import BuildEnvironment
Expand Down Expand Up @@ -66,15 +66,9 @@
("py:class", "differt.utils.TypeVarTuple"),
("py:class", "jax._src.typing.SupportsDType"),
("py:class", "ndarray"), # From ArrayLike
("py:mod", "equinox"),
("py:mod", "jaxtyping"),
("py:obj", "differt.utils._T"),
("py:obj", "differt.rt.utils._T"),
)
nitpick_ignore_regex = (
(r"py:.*", r"equinox\..*"),
(r"py:.*", r"jaxtyping\..*"),
)

# -- Intersphinx mapping

Expand Down Expand Up @@ -219,18 +213,53 @@ def fix_sionna_folder(_app: Sphinx, obj: Any, _bound_method: bool) -> None:


def fix_reference(
app: Sphinx, env: BuildEnvironment, node: pending_xref, contnode: TextElement
) -> Element | None:
app: Sphinx, env: BuildEnvironment, node: pending_xref, contnode: nodes.TextElement
) -> nodes.reference | None:
"""
Fix some intersphinx references that are broken.
"""
if node["refdomain"] == "py":
if node["reftarget"].startswith(
"equinox"
): # Sphinx fails to find them in the inventory
if node["reftarget"].endswith("Module"):
uri = (
"https://docs.kidger.site/equinox/api/module/module/#equinox.Module"
)
elif node["reftarget"].endswith("tree_at"):
uri = (
"https://docs.kidger.site/equinox/api/manipulation/#equinox.tree_at"
)
elif node["reftype"] == "mod":
uri = "https://docs.kidger.site/equinox/"
else:
return None

newnode = nodes.reference(
"", "", internal=False, refuri=uri, reftitle="(in equinox)"
)
newnode.append(contnode)

return newnode
if node["reftarget"].startswith(
"jaxtyping"
): # Sphinx fails to find them in the inventory
if node["reftype"] == "class":
uri = "https://docs.kidger.site/jaxtyping/api/array/#dtype"
elif node["reftype"] == "mod":
uri = "https://docs.kidger.site/jaxtyping/"
else:
return None

newnode = nodes.reference(
"", "", internal=False, refuri=uri, reftitle="(in jaxtyping)"
)
newnode.append(contnode)

return newnode
if node["reftarget"] == "plotly.graph_objs._figure.Figure":
node["reftarget"] = "plotly.graph_objects.Figure"
else:
return None

return missing_reference(app, env, node, contnode)
return missing_reference(app, env, node, contnode)

return None

Expand Down

0 comments on commit ef9013e

Please sign in to comment.