From 584832abbbb12063ec5581dcb7eec22a211cb472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 26 Jul 2023 14:05:46 +0200 Subject: [PATCH 1/4] feat(lib): change sigmoid to more general activation function --- differt2d/logic.py | 85 ++++++++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 36 deletions(-) diff --git a/differt2d/logic.py b/differt2d/logic.py index 41f69a5..fe919ee 100644 --- a/differt2d/logic.py +++ b/differt2d/logic.py @@ -22,6 +22,7 @@ from __future__ import annotations __all__ = [ + "activation", "disable_approx", "enable_approx", "greater", @@ -33,12 +34,11 @@ "logical_and", "logical_not", "logical_or", - "sigmoid", ] from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Literal import jax import jax.numpy as jnp @@ -49,7 +49,7 @@ _enable_approx = jax.config.define_bool_state( name="jax_enable_approx", default=True, - help=("Enable approximation using sigmoids."), + help=("Enable approximation using some activation function."), ) jit_approx = partial(jax.jit, inline=True, static_argnames=["approx"]) @@ -175,32 +175,30 @@ def disable_approx(disable: bool = True): yield -@partial(jax.jit, inline=True) -def sigmoid(x: Array, *, lambda_: float = 100.0) -> Array: +@partial(jax.jit, inline=True, static_argnames=("function",)) +def activation(x: Array, alpha: float = 1e2, function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid") -> Array: r""" Element-wise function for approximating a discrete transition between 0 and 1, - with a smoothed transition. + with a smoothed transition centered at :python:`x = 0.0`. - .. math:: - \text{sigmoid}(x;\lambda) = \frac{1}{1 + e^{-\lambda x}}, + Depending on the ``function`` argument, the activation function has the + following definition: - where :math:`\lambda` (:code:`lambda_`) is a slope parameter. + .. math:: + \text{sigmoid}(x;\alpha) = \frac{1}{1 + e^{-\alpha x}}, - See :func:`jax.nn.sigmoid` for more details. + or - .. note:: + .. math:: + \text{hard_sigmoid}(x;\alpha) = \frac{\text{relu6}(\alpha x+3)}{6}, - Using the above definition for the sigmoid will produce - undesirable effects when computing its gradient. This is why we rely - on JAX's implementation, that does not produce :code:`NaN` values - when :code:`x` is small. + where :math:`\alpha` (:code:`alpha`) is a slope parameter. - You can read more about this in - :sothread:`questions/68290850/jax-autograd-of-a-sigmoid-always-returns-nan`. + See :func:`jax.nn.sigmoid` or :func:`jax.nn.hard_sigmoid` for more details. :param x: The input array. - :param `lambda_`: The slope parameter. - :return: The corresponding sigmoid values. + :param alpha: The slope parameter. + :return: The corresponding values. :EXAMPLES: @@ -209,20 +207,35 @@ def sigmoid(x: Array, *, lambda_: float = 100.0) -> Array: import matplotlib.pyplot as plt import numpy as np - from differt2d.logic import sigmoid + from differt2d.logic import activation + from jax import grad, vmap x = np.linspace(-5, +5, 200) - for lambda_ in [1, 10, 100]: - y = sigmoid(x, lambda_=lambda_) - _ = plt.plot(x, y, "--", label=f"$\\lambda = {lambda_}$") + _, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[6.4, 8]) + + for function in ["sigmoid", "hard_sigmoid"]: + def f(x): + return activation(x, alpha=1.0, function=function) + + y = f(x) + dydx = vmap(grad(f))(x) + _ = ax1.plot(x, y, "--", label=f"{function}") + _ = ax2.plot(x, dydx, "-", label=f"{function}") - plt.xlabel("$x$") - plt.ylabel(r"sigmoid$(x;\lambda)$") + ax2.set_xlabel("$x$") + ax1.set_ylabel("$f(x)$") + ax2.set_ylabel(r"$\frac{\partial f(x)}{\partial x}$") plt.legend() + plt.tight_layout() plt.show() """ - return jax.nn.sigmoid(lambda_ * x) + if function == "sigmoid": + return jax.nn.sigmoid(alpha * x) + elif function == "hard_sigmoid": + return jax.nn.hard_sigmoid(alpha * x) + else: + raise ValueError(f"Unknown function '{function}'") @jit_approx @@ -291,7 +304,7 @@ def greater( Element-wise logical :python:`x > y`. Calls :func:`jax.numpy.subtract` - then :func:`sigmoid` if approximation is enabled, + then :func:`activation` if approximation is enabled, :func:`jax.numpy.greater` otherwise. :param x: The first input array. @@ -301,7 +314,7 @@ def greater( """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return sigmoid(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y) + return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y) @jit_approx @@ -312,7 +325,7 @@ def greater_equal( Element-wise logical :python:`x >= y`. Calls :func:`jax.numpy.subtract` - then :func:`sigmoid` if approximation is enabled, + then :func:`activation` if approximation is enabled, :func:`jax.numpy.greater_equal` otherwise. :param x: The first input array. @@ -322,7 +335,7 @@ def greater_equal( """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return sigmoid(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y) + return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y) @jit_approx @@ -331,7 +344,7 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra Element-wise logical :python:`x < y`. Calls :func:`jax.numpy.subtract` (arguments swapped) - then :func:`sigmoid` if approximation is enabled, + then :func:`activation` if approximation is enabled, :func:`jax.numpy.less` otherwise. :param x: The first input array. @@ -341,7 +354,7 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return sigmoid(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y) + return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y) @jit_approx @@ -350,7 +363,7 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) - Element-wise logical :python:`x <= y`. Calls :func:`jax.numpy.subtract` (arguments swapped) - then :func:`sigmoid` if approximation is enabled, + then :func:`activation` if approximation is enabled, :func:`jax.numpy.less_equal` otherwise. :param x: The first input array. @@ -360,11 +373,11 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) - """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return sigmoid(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y) + return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y) @jit_approx -def is_true(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Array: +def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array: """ Element-wise check if a given truth value can be considered to be true. @@ -383,7 +396,7 @@ def is_true(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Ar @jit_approx -def is_false(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Array: +def is_false(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array: """ Element-wise check if a given truth value can be considered to be false. From 70df53cb87da393854ba0a7b11e1ba0a7fd40188 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jul 2023 12:06:06 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- differt2d/logic.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/differt2d/logic.py b/differt2d/logic.py index fe919ee..2dc4ed9 100644 --- a/differt2d/logic.py +++ b/differt2d/logic.py @@ -38,7 +38,7 @@ from contextlib import contextmanager from functools import partial -from typing import TYPE_CHECKING, Optional, Literal +from typing import TYPE_CHECKING, Literal, Optional import jax import jax.numpy as jnp @@ -176,7 +176,11 @@ def disable_approx(disable: bool = True): @partial(jax.jit, inline=True, static_argnames=("function",)) -def activation(x: Array, alpha: float = 1e2, function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid") -> Array: +def activation( + x: Array, + alpha: float = 1e2, + function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid", +) -> Array: r""" Element-wise function for approximating a discrete transition between 0 and 1, with a smoothed transition centered at :python:`x = 0.0`. @@ -335,7 +339,9 @@ def greater_equal( """ if approx is None: approx = jax.config.jax_enable_approx # type: ignore[attr-defined] - return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y) + return ( + activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y) + ) @jit_approx From 3496c0e66a789bb7b0dea283acca33d6a4aa3168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 26 Jul 2023 14:17:11 +0200 Subject: [PATCH 3/4] chore(tests): test activation function --- tests/test_logic.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_logic.py b/tests/test_logic.py index 2a8bae9..0aea7cb 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -1,9 +1,10 @@ import chex import jax import jax.numpy as jnp +import pytest from jax import disable_jit -from differt2d.logic import enable_approx, is_true +from differt2d.logic import activation, enable_approx, is_true def test_enable_approx(): @@ -94,3 +95,16 @@ def test_enable_approx_with_keyword(): got = is_true(True, approx=False) chex.assert_trees_all_equal(expected, got) chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@pytest.mark.parametrize( + ("function", "jax_fun"), + [("sigmoid", jax.nn.sigmoid), ("hard_sigmoid", jax.nn.hard_sigmoid)], +) +@pytest.mark.parametrize(("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)]) +def test_activation(function, jax_fun, alpha): + x = jnp.linspace(-5, +5, 200) + expected = jax_fun(alpha * x) + got = activation(x, alpha=alpha, function=function) + chex.assert_trees_all_close(expected, got) + chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) From 86b6dd1516c85ed563c7187d95b5d45bd37b6e51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 26 Jul 2023 14:22:32 +0200 Subject: [PATCH 4/4] chore(tests): check invalid inputs --- tests/test_logic.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_logic.py b/tests/test_logic.py index 0aea7cb..84856d8 100644 --- a/tests/test_logic.py +++ b/tests/test_logic.py @@ -108,3 +108,12 @@ def test_activation(function, jax_fun, alpha): got = activation(x, alpha=alpha, function=function) chex.assert_trees_all_close(expected, got) chex.assert_trees_all_equal_shapes_and_dtypes(expected, got) + + +@pytest.mark.parametrize( + ("function",), [("relu",), ("SIGMOID",), ("HARD_SIGMOID",), ("hard-sigmoid",)] +) +def test_invalid_activation(function): + with pytest.raises(ValueError) as e: + activation(1.0, function=function) + assert "Unknown" in str(e)