Skip to content

Commit

Permalink
Add a unit test to cover #15012.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 387844394
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Jul 30, 2021
1 parent b693bb8 commit 2d79a2e
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions keras/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self):
result = f([t_val, p_val])
self.assertArrayNear(result, [.002, 0, .17], 1e-3)

@combinations.generate(combinations.combine(mode=['eager']))
def test_sparse_categorical_crossentropy_with_float16(self):
# See https://github.com/keras-team/keras/issues/15012 for more details.
# we don't cast y_true to have same dtype as y_pred, since y_pred could be
# float16 which has a small upbound, and the casting could cause an
# underflow. The y_true will be used as int64 anyway.

# create 2 observations with 2049 labels, since 2048 is the largest number
# for float16
y_true = [0, 2049]
# should result in a loss close to 0 since predicting y_true perfectly
y_pred = np.zeros((2, 2050))
y_pred[0][0] = 1
y_pred[1][2049] = 1
y_pred_16 = tf.convert_to_tensor(y_pred, dtype=tf.float16)

# If we did a cast for y_true to float16 in SparseCategoricalCrossentropy,
# then the loss will not be zero.
scce = losses.SparseCategoricalCrossentropy()
self.assertAllClose(scce(y_true, y_pred_16).numpy(), 0.0, atol=1e-3)

@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_binary_crossentropy_loss(self):
target = backend.variable(np.random.randint(0, 1, (5, 1)))
Expand Down

0 comments on commit 2d79a2e

Please sign in to comment.