Add ClippedBinaryCrossEntropy test for all float dtypes
julienperichon committed Jul 29, 2022
1 parent 6252b65 commit 18a6434
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions keras_fsl/losses/tests/gram_matrix_losses_test.py
@@ -140,9 +140,13 @@ def test_clipped_loss_should_equal_literal_calculation(self, y_true, adjacency_matrix, y_pred):
         )
         np.testing.assert_almost_equal(tf_loss, np_loss, decimal=5)
 
-    def test_clipped_loss_computes_in_float16(self, y_true, y_pred):
+    @staticmethod
+    @pytest.mark.parametrize("dtype_policy", (tf.float16, tf.bfloat16, tf.float32, tf.float64))
+    def test_clipped_loss_computes_in_all_float_dtypes(dtype_policy, y_true, y_pred):
+        y_true_tensor = tf.convert_to_tensor(y_true)
+        y_pred_tensor = tf.convert_to_tensor(y_pred)
         ClippedBinaryCrossentropy(lower=0.05, upper=0.95)(
-            tf.convert_to_tensor(y_true, tf.float16), tf.convert_to_tensor(y_pred, tf.float16)
+            tf.cast(y_true_tensor, dtype=dtype_policy), tf.cast(y_pred_tensor, dtype=dtype_policy)
         )
 
     def test_max_loss_should_equal_literal_calculation(self, y_true, adjacency_matrix, y_pred):
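For context, here is a minimal, self-contained sketch of the same parametrization pattern outside the test class. It assumes ClippedBinaryCrossentropy is importable from keras_fsl.losses.gram_matrix_losses (the module this test file covers) and uses made-up y_true / y_pred arrays in place of the repository's pytest fixtures.

    import numpy as np
    import pytest
    import tensorflow as tf

    # Assumed import path; the loss is defined in the module this test file exercises.
    from keras_fsl.losses.gram_matrix_losses import ClippedBinaryCrossentropy


    @pytest.mark.parametrize("dtype_policy", (tf.float16, tf.bfloat16, tf.float32, tf.float64))
    def test_clipped_loss_runs_in_each_float_dtype(dtype_policy):
        # Illustrative stand-ins for the repository's y_true / y_pred fixtures.
        y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
        y_pred = np.array([[0.9, 0.1], [0.2, 0.8]])
        # As in the committed test, passing simply means the call does not raise for this dtype.
        ClippedBinaryCrossentropy(lower=0.05, upper=0.95)(
            tf.cast(tf.convert_to_tensor(y_true), dtype=dtype_policy),
            tf.cast(tf.convert_to_tensor(y_pred), dtype=dtype_policy),
        )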
