
Commit

Merge branch 'add-gn' of https://github.com/gbruno16/optax into add-gn
gbruno16 committed Apr 9, 2024
2 parents 37bdd48 + 226d982 commit c3fe42e
Showing 3 changed files with 302 additions and 117 deletions.
48 changes: 47 additions & 1 deletion optax/losses/_classification.py
@@ -134,6 +134,52 @@ def binary_sparsemax_loss(logits, labels):
  return sparsemax_loss(logits, labels)


@jax.custom_jvp
def weighted_logsoftmax(x: chex.Array, weights: chex.Array) -> chex.Array:
  r"""Weighted logsoftmax.

  Computes

  .. math::
    \left(w_i \log\frac{\exp x_i}{\sum_j \exp x_j}\right)_{i=1}^n

  for :math:`x` the input ``x`` and :math:`w` the ``weights``.

  For :math:`w_i = 0` and :math:`x_i = -\infty`, this implementation ensures
  that the i-th entry of the output is 0 rather than nan, following the
  convention that :math:`0 \log 0 = 0`.

  Args:
    x: input array.
    weights: weights.

  Returns:
    logsoftmax of ``x`` multiplied elementwise by ``weights``.
  """
  logsoftmax_x = jax.nn.log_softmax(x, axis=-1)
  return jnp.where(
      weights != 0.0, weights * logsoftmax_x, jnp.zeros_like(logsoftmax_x)
  )


def _weighted_logsoftmax_jvp(primals, tangents):
  """Custom JVP of weighted logsoftmax."""
  (x, weights) = primals
  (x_dot, weights_dot) = tangents
  logsoftmax_x = jax.nn.log_softmax(x, axis=-1)
  # Primal output, matching weighted_logsoftmax: zero wherever weights == 0.
  result = jnp.where(
      weights != 0.0, weights * logsoftmax_x, jnp.zeros_like(logsoftmax_x)
  )
  # Tangent: the x-direction uses d log_softmax(x) = x_dot - <x_dot, softmax(x)>,
  # scaled by weights; the weights-direction contributes weights_dot * log_softmax(x).
  out_tangents = (
      weights * x_dot
      - weights
      * jnp.sum(x_dot * jax.nn.softmax(x, axis=-1), axis=-1, keepdims=True)
      + weights_dot * logsoftmax_x
  )
  return result, out_tangents


weighted_logsoftmax.defjvp(_weighted_logsoftmax_jvp)
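
As a quick illustration of the convention described in the docstring (my own sketch, not part of the diff; it assumes `weighted_logsoftmax` is in scope, e.g. imported from `optax.losses._classification` on this branch), a zero weight on a fully masked logit yields a 0 entry, whereas the naive product yields nan:

import jax
import jax.numpy as jnp

x = jnp.array([-jnp.inf, 0.0, 1.0])  # first logit is fully masked
w = jnp.array([0.0, 0.5, 0.5])       # zero weight on the masked entry

naive = w * jax.nn.log_softmax(x, axis=-1)  # entry 0 is 0 * (-inf) -> nan
safe = weighted_logsoftmax(x, w)            # entry 0 is 0, by the 0 log 0 = 0 convention
print(naive)  # approximately [nan, -0.657, -0.157]
print(safe)   # approximately [0.0, -0.657, -0.157]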


def softmax_cross_entropy(
    logits: chex.Array,
    labels: chex.Array,
@@ -159,7 +205,7 @@ def softmax_cross_entropy(
    distributions, with shape `[...]`.
  """
  chex.assert_type([logits], float)
-  return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
+  return -jnp.sum(weighted_logsoftmax(logits, labels), axis=-1)


def softmax_cross_entropy_with_integer_labels(
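
A corresponding sketch of the behavioral change in `softmax_cross_entropy` (again my own illustration, assuming the function is exposed as `optax.losses.softmax_cross_entropy`): with a zero label on a `-inf` logit, the previous expression `labels * log_softmax(logits)` turns the whole loss into nan, while the weighted form keeps it finite:

import jax
import jax.numpy as jnp
from optax.losses import softmax_cross_entropy

logits = jnp.array([[-jnp.inf, 0.0, 1.0]])  # class 0 is masked out
labels = jnp.array([[0.0, 0.5, 0.5]])       # and carries zero target probability

old_style = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
new_style = softmax_cross_entropy(logits, labels)
print(old_style)  # approximately [nan]
print(new_style)  # approximately [0.813], finite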
