From 351944c7848d6f0aa52dd8ba1753cbf06d760dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 26 Jul 2023 15:17:43 +0200 Subject: [PATCH] chore(tests): much more tests, and docs fixes --- README.md | 2 +- differt2d/logic.py | 59 ++++++++------- differt2d/optimize.py | 9 ++- tests/conftest.py | 6 ++ tests/test_logic.py | 163 +++++++++++++++++++++++++++++++++++++++++- 5 files changed, 210 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index fec576a..0a0ba57 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ - Manim Slides Logo + DiffeRT2d Logo [![Documentation][documentation-badge]][documentation-url] diff --git a/differt2d/logic.py b/differt2d/logic.py index 2dc4ed9..4388ea1 100644 --- a/differt2d/logic.py +++ b/differt2d/logic.py @@ -38,7 +38,7 @@ from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import jax import jax.numpy as jnp @@ -52,7 +52,7 @@ help=("Enable approximation using some activation function."), ) -jit_approx = partial(jax.jit, inline=True, static_argnames=["approx"]) +jit_approx = partial(jax.jit, inline=True, static_argnames=["approx", "function"]) @contextmanager @@ -175,11 +175,11 @@ def disable_approx(disable: bool = True): yield -@partial(jax.jit, inline=True, static_argnames=("function",)) +@partial(jax.jit, inline=True, static_argnames=["function"]) def activation( x: Array, alpha: float = 1e2, - function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid", + function: Literal["sigmoid", "hard_sigmoid"] = "hard_sigmoid", ) -> Array: r""" Element-wise function for approximating a discrete transition between 0 and 1, @@ -242,8 +242,8 @@ def f(x): raise ValueError(f"Unknown function '{function}'") -@jit_approx -def logical_or(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx"]) +def logical_or(x: Array, y: Array, approx: Optional[bool] = None) -> Array: """ Element-wise logical :python:`x or y`. @@ -260,12 +260,12 @@ def logical_or(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array: return jnp.maximum(x, y) if approx else jnp.logical_or(x, y) -@jit_approx -def logical_and(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx"]) +def logical_and(x: Array, y: Array, approx: Optional[bool] = None) -> Array: """ Element-wise logical :python:`x and y`. - Calls :func:`jax.numpy.maximum` if approximation is enabled, + Calls :func:`jax.numpy.minimum` if approximation is enabled, :func:`jax.numpy.logical_or` otherwise. :param x: The first input array. @@ -275,11 +275,11 @@ def logical_and(x: Array, y: Array, *, approx: Optional[bool] = None) -> Array: """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return jnp.multiply(x, y) if approx else jnp.logical_and(x, y) + return jnp.minimum(x, y) if approx else jnp.logical_and(x, y) -@jit_approx -def logical_not(x: Array, *, approx: Optional[bool] = None) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx"]) +def logical_not(x: Array, approx: Optional[bool] = None) -> Array: """ Element-wise logical :python:`not x`. @@ -296,13 +296,12 @@ def logical_not(x: Array, *, approx: Optional[bool] = None) -> Array: return jnp.subtract(1.0, x) if approx else jnp.logical_not(x) -@jit_approx +@partial(jax.jit, inline=True, static_argnames=["approx", "function"]) def greater( x: Array, y: Array, - *, approx: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> Array: """ Element-wise logical :python:`x > y`. @@ -314,6 +313,8 @@ def greater( :param x: The first input array. :param y: The second input array. :param approx: Whether approximation is enabled or not. + :param kwargs: + Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ if approx is None: @@ -321,9 +322,9 @@ def greater( return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y) -@jit_approx +@partial(jax.jit, inline=True, static_argnames=["approx", "function"]) def greater_equal( - x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs + x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any ) -> Array: """ Element-wise logical :python:`x >= y`. @@ -335,6 +336,8 @@ def greater_equal( :param x: The first input array. :param y: The second input array. :param approx: Whether approximation is enabled or not. + :param kwargs: + Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ if approx is None: @@ -344,8 +347,8 @@ def greater_equal( ) -@jit_approx -def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx", "function"]) +def less(x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any) -> Array: """ Element-wise logical :python:`x < y`. @@ -356,6 +359,8 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra :param x: The first input array. :param y: The second input array. :param approx: Whether approximation is enabled or not. + :param kwargs: + Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ if approx is None: @@ -363,8 +368,10 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y) -@jit_approx -def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx", "function"]) +def less_equal( + x: Array, y: Array, approx: Optional[bool] = None, **kwargs: Any +) -> Array: """ Element-wise logical :python:`x <= y`. @@ -375,6 +382,8 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) - :param x: The first input array. :param y: The second input array. :param approx: Whether approximation is enabled or not. + :param kwargs: + Keyword arguments to be passed to :func:`activation`. :return: Output array, with element-wise comparison. """ if approx is None: @@ -382,8 +391,8 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) - return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y) -@jit_approx -def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx"]) +def is_true(x: Array, tol: float = 0.05, approx: Optional[bool] = None) -> Array: """ Element-wise check if a given truth value can be considered to be true. @@ -401,8 +410,8 @@ def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Ar return jnp.greater(x, 1.0 - tol) if approx else jnp.asarray(x) -@jit_approx -def is_false(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array: +@partial(jax.jit, inline=True, static_argnames=["approx"]) +def is_false(x: Array, tol: float = 0.05, approx: Optional[bool] = None) -> Array: """ Element-wise check if a given truth value can be considered to be false. diff --git a/differt2d/optimize.py b/differt2d/optimize.py index 1950cd7..08417c2 100644 --- a/differt2d/optimize.py +++ b/differt2d/optimize.py @@ -46,6 +46,13 @@ def default_optimizer() -> optax.GradientTransformation: Useful to override the :func:`repr` method in the documentation. + .. note:: + + This optimizer should be a good default choise when used by + :class:`MinPath` as it gave the + best convergence results when compared to other optimizers + provided by :mod:`optax`. + :return: The default optimizer. :Examples: @@ -155,7 +162,7 @@ def minimize_many_random_uniform( :param many: How many times the minimization should be performed. :param kwargs: - Keyword arguments to be passed to :func:`minimize`. + Keyword arguments to be passed to :func:`minimize_random_uniform`. :return: The solution array and the corresponding loss. :Examples: diff --git a/tests/conftest.py b/tests/conftest.py index a72eb24..d53697d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,12 @@ +import jax import pytest @pytest.fixture def seed(): yield 1234 + + +@pytest.fixture +def key(seed): + yield jax.random.PRNGKey(seed) diff --git a/tests/test_logic.py b/tests/test_logic.py index 84856d8..6ace063 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -4,7 +4,40 @@ import pytest from jax import disable_jit -from differt2d.logic import activation, enable_approx, is_true +from differt2d.logic import ( + activation, + enable_approx, + greater, + greater_equal, + is_false, + is_true, + less, + less_equal, + logical_and, + logical_not, + logical_or, +) + +approx = pytest.mark.parametrize(("approx",), [(True,), (False,)]) +alpha = pytest.mark.parametrize( + ("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)] +) +function = pytest.mark.parametrize(("function",), [("sigmoid",), ("hard_sigmoid",)]) +tol = pytest.mark.parametrize(("tol",), [(0.05,), (0.5,)]) + + +@pytest.fixture +def x(key): + x = jax.random.uniform(key, (200,)) + yield x + + +@pytest.fixture +def xy(key): + key1, key2 = jax.random.split(key) + x = jax.random.uniform(key1, (200,)) + y = jax.random.uniform(key2, (200,)) + yield x, y def test_enable_approx(): @@ -101,7 +134,7 @@ def test_enable_approx_with_keyword(): ("function", "jax_fun"), [("sigmoid", jax.nn.sigmoid), ("hard_sigmoid", jax.nn.hard_sigmoid)], ) -@pytest.mark.parametrize(("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)]) +@alpha def test_activation(function, jax_fun, alpha): x = jnp.linspace(-5, +5, 200) expected = jax_fun(alpha * x) @@ -117,3 +150,129 @@ def test_invalid_activation(function): with pytest.raises(ValueError) as e: activation(1.0, function=function) assert "Unknown" in str(e) + + +@approx +def test_logical_or(xy, approx): + x, y = xy + if approx: + expected = jnp.maximum(x, y) + else: + expected = jnp.logical_or(x, y) + + got = logical_or(x, y, approx=approx) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +def test_logical_and(xy, approx): + x, y = xy + if approx: + expected = jnp.minimum(x, y) + else: + expected = jnp.logical_and(x, y) + + got = logical_and(x, y, approx=approx) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +def test_logical_not(x, approx): + if approx: + expected = jnp.subtract(1.0, x) + else: + expected = jnp.logical_not(x) + + got = logical_not(x, approx=approx) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@alpha +@function +def test_greater(xy, approx, alpha, function): + x, y = xy + if approx: + expected = activation(x - y, alpha=alpha, function=function) + else: + expected = jnp.greater(x, y) + + got = greater(x, y, approx=approx, alpha=alpha, function=function) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@alpha +@function +def test_greater_equal(xy, approx, alpha, function): + x, y = xy + if approx: + expected = activation(x - y, alpha=alpha, function=function) + else: + expected = jnp.greater_equal(x, y) + + got = greater_equal(x, y, approx=approx, alpha=alpha, function=function) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@alpha +@function +def test_less(xy, approx, alpha, function): + x, y = xy + if approx: + expected = activation(y - x, alpha=alpha, function=function) + else: + expected = jnp.less(x, y) + + got = less(x, y, approx=approx, alpha=alpha, function=function) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@alpha +@function +def test_less_equal(xy, approx, alpha, function): + x, y = xy + if approx: + expected = activation(y - x, alpha=alpha, function=function) + else: + expected = jnp.less_equal(x, y) + + got = less_equal(x, y, approx=approx, alpha=alpha, function=function) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@tol +def test_is_true(x, approx, tol): + if approx: + expected = jnp.greater(x, 1.0 - tol) + else: + x = jnp.greater(x, 1.0 - tol) + expected = x + + got = is_true(x, tol=tol, approx=approx) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@approx +@tol +def test_is_false(x, approx, tol): + if approx: + expected = jnp.less(x, tol) + else: + x = jnp.greater(x, tol) + expected = jnp.logical_not(x) + + got = is_false(x, tol=tol, approx=approx) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got)