Merge pull request #14 from jeertmans/activation
feat(lib): change sigmoid to more general activation function
jeertmans authored Jul 26, 2023
2 parents dd52f79 + 86b6dd1 commit 0afdf70
Showing 2 changed files with 79 additions and 37 deletions.
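In short, the former sigmoid() helper in differt2d/logic.py is replaced by a more general activation() function taking a slope parameter alpha and a static function argument ("sigmoid" or "hard_sigmoid"), and the approximate comparison helpers now call it. A minimal usage sketch, based only on the signatures visible in the diff below (the numeric values are illustrative):

```python
import jax.numpy as jnp

from differt2d.logic import activation

x = jnp.linspace(-1.0, 1.0, 5)

# Smooth approximation of a unit step at x = 0 (default function="sigmoid").
y_smooth = activation(x, alpha=100.0)

# Piecewise-linear variant built on jax.nn.hard_sigmoid.
y_hard = activation(x, alpha=100.0, function="hard_sigmoid")
```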
91 changes: 55 additions & 36 deletions differt2d/logic.py
@@ -22,6 +22,7 @@
from __future__ import annotations

__all__ = [
"activation",
"disable_approx",
"enable_approx",
"greater",
@@ -33,12 +34,11 @@
"logical_and",
"logical_not",
"logical_or",
"sigmoid",
]

from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

import jax
import jax.numpy as jnp
@@ -49,7 +49,7 @@
_enable_approx = jax.config.define_bool_state(
name="jax_enable_approx",
default=True,
help=("Enable approximation using sigmoids."),
help=("Enable approximation using some activation function."),
)

jit_approx = partial(jax.jit, inline=True, static_argnames=["approx"])
@@ -175,32 +175,34 @@ def disable_approx(disable: bool = True):
yield


@partial(jax.jit, inline=True)
def sigmoid(x: Array, *, lambda_: float = 100.0) -> Array:
@partial(jax.jit, inline=True, static_argnames=("function",))
def activation(
x: Array,
alpha: float = 1e2,
function: Literal["sigmoid", "hard_sigmoid"] = "sigmoid",
) -> Array:
r"""
Element-wise function for approximating a discrete transition between 0 and 1,
with a smoothed transition.
with a smoothed transition centered at :python:`x = 0.0`.
.. math::
\text{sigmoid}(x;\lambda) = \frac{1}{1 + e^{-\lambda x}},
Depending on the ``function`` argument, the activation function has the
following definition:
where :math:`\lambda` (:code:`lambda_`) is a slope parameter.
.. math::
\text{sigmoid}(x;\alpha) = \frac{1}{1 + e^{-\alpha x}},
See :func:`jax.nn.sigmoid` for more details.
or
.. note::
.. math::
\text{hard_sigmoid}(x;\alpha) = \frac{\text{relu6}(\alpha x+3)}{6},
Using the above definition for the sigmoid will produce
undesirable effects when computing its gradient. This is why we rely
on JAX's implementation, that does not produce :code:`NaN` values
when :code:`x` is small.
where :math:`\alpha` (:code:`alpha`) is a slope parameter.
You can read more about this in
:sothread:`questions/68290850/jax-autograd-of-a-sigmoid-always-returns-nan`.
See :func:`jax.nn.sigmoid` or :func:`jax.nn.hard_sigmoid` for more details.
:param x: The input array.
:param `lambda_`: The slope parameter.
:return: The corresponding sigmoid values.
:param alpha: The slope parameter.
:return: The corresponding values.
:EXAMPLES:
@@ -209,20 +211,35 @@ def sigmoid(x: Array, *, lambda_: float = 100.0) -> Array:
import matplotlib.pyplot as plt
import numpy as np
from differt2d.logic import sigmoid
from differt2d.logic import activation
from jax import grad, vmap
x = np.linspace(-5, +5, 200)
for lambda_ in [1, 10, 100]:
y = sigmoid(x, lambda_=lambda_)
_ = plt.plot(x, y, "--", label=f"$\\lambda = {lambda_}$")
_, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=[6.4, 8])
for function in ["sigmoid", "hard_sigmoid"]:
def f(x):
return activation(x, alpha=1.0, function=function)
y = f(x)
dydx = vmap(grad(f))(x)
_ = ax1.plot(x, y, "--", label=f"{function}")
_ = ax2.plot(x, dydx, "-", label=f"{function}")
plt.xlabel("$x$")
plt.ylabel(r"sigmoid$(x;\lambda)$")
ax2.set_xlabel("$x$")
ax1.set_ylabel("$f(x)$")
ax2.set_ylabel(r"$\frac{\partial f(x)}{\partial x}$")
plt.legend()
plt.tight_layout()
plt.show()
"""
return jax.nn.sigmoid(lambda_ * x)
if function == "sigmoid":
return jax.nn.sigmoid(alpha * x)
elif function == "hard_sigmoid":
return jax.nn.hard_sigmoid(alpha * x)
else:
raise ValueError(f"Unknown function '{function}'")
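Since function is listed in static_argnames, each supported string gets its own jitted specialization, and an unsupported name raises the ValueError above at trace time. A short sketch of both paths, assuming the differt2d.logic import path:

```python
import jax.numpy as jnp

from differt2d.logic import activation

x = jnp.array([-1.0, 0.0, 1.0])

activation(x, function="sigmoid")       # compiled once for this static value
activation(x, function="hard_sigmoid")  # a second specialization is compiled

try:
    activation(x, function="relu")      # not one of the supported literals
except ValueError as err:
    print(err)  # Unknown function 'relu'
```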


@jit_approx
@@ -291,7 +308,7 @@ def greater(
Element-wise logical :python:`x > y`.
Calls :func:`jax.numpy.subtract`
then :func:`sigmoid` if approximation is enabled,
then :func:`activation` if approximation is enabled,
:func:`jax.numpy.greater` otherwise.
:param x: The first input array.
Expand All @@ -301,7 +318,7 @@ def greater(
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return sigmoid(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y)
return activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater(x, y)
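With the approximation enabled, the comparison therefore reduces to the activation of a difference; a small sketch of that equivalence (values are illustrative):

```python
import jax.numpy as jnp

from differt2d.logic import activation, greater

x = jnp.array([0.5])
y = jnp.array([0.2])

smooth = greater(x, y, approx=True, alpha=100.0)
manual = activation(jnp.subtract(x, y), alpha=100.0)
assert jnp.allclose(smooth, manual)  # both evaluate sigmoid(alpha * (x - y))

exact = greater(x, y, approx=False)  # falls back to jnp.greater (booleans)
```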


@jit_approx
@@ -312,7 +329,7 @@ def greater_equal(
Element-wise logical :python:`x >= y`.
Calls :func:`jax.numpy.subtract`
then :func:`sigmoid` if approximation is enabled,
then :func:`activation` if approximation is enabled,
:func:`jax.numpy.greater_equal` otherwise.
:param x: The first input array.
@@ -322,7 +339,9 @@
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return sigmoid(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y)
return (
activation(jnp.subtract(x, y), **kwargs) if approx else jnp.greater_equal(x, y)
)


@jit_approx
@@ -331,7 +350,7 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra
Element-wise logical :python:`x < y`.
Calls :func:`jax.numpy.subtract` (arguments swapped)
then :func:`sigmoid` if approximation is enabled,
then :func:`activation` if approximation is enabled,
:func:`jax.numpy.less` otherwise.
:param x: The first input array.
Expand All @@ -341,7 +360,7 @@ def less(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -> Arra
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return sigmoid(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y)
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less(x, y)


@jit_approx
@@ -350,7 +369,7 @@ def less_equal(x: Array, y: Array, *, approx: Optional[bool] = None, **kwargs) -
Element-wise logical :python:`x <= y`.
Calls :func:`jax.numpy.subtract` (arguments swapped)
then :func:`sigmoid` if approximation is enabled,
then :func:`activation` if approximation is enabled,
:func:`jax.numpy.less_equal` otherwise.
:param x: The first input array.
@@ -360,11 +379,11 @@
"""
if approx is None:
approx = jax.config.jax_enable_approx # type: ignore[attr-defined]
return sigmoid(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y)
return activation(jnp.subtract(y, x), **kwargs) if approx else jnp.less_equal(x, y)


@jit_approx
def is_true(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Array:
def is_true(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array:
"""
Element-wise check if a given truth value can be considered to be true.
@@ -383,7 +402,7 @@ def is_true(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Ar


@jit_approx
def is_false(x: Array, *, tol: float = 0.05, approx: Optional[bool] = None) -> Array:
def is_false(x: Array, *, tol: float = 0.01, approx: Optional[bool] = None) -> Array:
"""
Element-wise check if a given truth value can be considered to be false.
25 changes: 24 additions & 1 deletion tests/test_logic.py
@@ -1,9 +1,10 @@
import chex
import jax
import jax.numpy as jnp
import pytest
from jax import disable_jit

from differt2d.logic import enable_approx, is_true
from differt2d.logic import activation, enable_approx, is_true


def test_enable_approx():
@@ -94,3 +95,25 @@ def test_enable_approx_with_keyword():
got = is_true(True, approx=False)
chex.assert_trees_all_equal(expected, got)
chex.assert_trees_all_equal_shapes_and_dtypes(expected, got)


@pytest.mark.parametrize(
("function", "jax_fun"),
[("sigmoid", jax.nn.sigmoid), ("hard_sigmoid", jax.nn.hard_sigmoid)],
)
@pytest.mark.parametrize(("alpha",), [(1e-3,), (1e-2,), (1e-1,), (1e-0,), (1e1,)])
def test_activation(function, jax_fun, alpha):
x = jnp.linspace(-5, +5, 200)
expected = jax_fun(alpha * x)
got = activation(x, alpha=alpha, function=function)
chex.assert_trees_all_close(expected, got)
chex.assert_trees_all_equal_shapes_and_dtypes(expected, got)


@pytest.mark.parametrize(
("function",), [("relu",), ("SIGMOID",), ("HARD_SIGMOID",), ("hard-sigmoid",)]
)
def test_invalid_activation(function):
with pytest.raises(ValueError) as e:
activation(1.0, function=function)
assert "Unknown" in str(e)
