Skip to content

Commit

Permalink
fix(docs): typos and phrasing in docs/source/jax_and_jaxtyping.md (#72
Browse files Browse the repository at this point in the history
)

* Typos and phrasing in jax_and_jaxtyping.md

* Update jax_and_jaxtyping.md
  • Loading branch information
idoby authored Jun 27, 2024
1 parent 59822cd commit ab6f9fb
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions docs/source/jax_and_jaxtyping.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ and accelerated algebra (*XLA*) in the same Python package.

Because we use JAX arrays everywhere in the code,
you can evaluate the gradient of nearly any function
by wrapping it inside {func}`jax.grad`. E.g.:
by wrapping it inside {func}`jax.grad`, e.g.:

```python
import jax
Expand Down Expand Up @@ -40,21 +40,21 @@ for more details.
## Understanding JAX Arrays

The main advantage of JAX over other *autodiff* libraries
(like PyTorch or TensforFlow) is that you can pretty much use
it as a drop-in replacement of [NumPy](https://numpy.org/),
because almost all function from NumPy is present in
{mod}`jax.numpy`.
(like PyTorch or TensorFlow) is that you can pretty much use
it as a drop-in replacement for [NumPy](https://numpy.org/),
because JAX contains a NumPy-compatible implementation of almost
all NumPy functions in {mod}`jax.numpy`.

JAX also comes with the concept of
[PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html),
and JAX arrays are PyTrees.

The only thing your **really need to know** is that PyTrees
The only thing you **really need to know** is that PyTrees
are **immutable**. Hence, for compatibility with JAX's philosophy,
DiffeRT2d's object are also **immutable PyTrees**.
DiffeRT2d's objects are also **immutable PyTrees**.

As such, every methods on those objects are likely to return
a **new object instance**. E.g.:
As such, every method these objects expose is likely to return
a **new object instance**, e.g.:

```python
from differt.geometry import Point
Expand All @@ -68,9 +68,8 @@ scene = scene.with_transmitters(tx=Point()) # Do
```

Finally, because our objects are PyTrees,
you can use {func}`equinox.tree_at` to *mutate*[^1] a PyTree.

[^1]: Again, the mutation will actually return a new object.
you can use {func}`equinox.tree_at` to *transform* a PyTree,
i.e., returning a new object after the transformation.

## Type-checking JAX Arrays

Expand All @@ -79,14 +78,13 @@ is usually quite hard to enforce that a specific input array
must be, e.g., two-dimensional.

With [`jaxtyping`](https://docs.kidger.site/jaxtyping/),
we provide both meaningful type hints we
**all** our API, but also strong type checking at runtime[^2].

[^2]: Thanks to just-in-time compilation, the overhead of runtime type
checking is minimal and only performed once for each compiled version
of a given function.
we provide both meaningful type hints for every method in our API,
as well as enforce strong type and dimensionality checking at runtime.
Thanks to just-in-time compilation, the overhead of runtime type
checking is minimal and only performed once for each compiled version
of a given function.

E.g., the following code indicates that
For example, the following code illustrates that
function `my_function` takes a 2D array of floating-point values, `x`,
and returns a 1D array whose length matches the first axis of
the input array `x`.
Expand Down

0 comments on commit ab6f9fb

Please sign in to comment.