From a053879f58997d6a2ee18b7954e1e0d0388eeb1f Mon Sep 17 00:00:00 2001 From: Artem Date: Fri, 24 Feb 2023 15:38:15 +0000 Subject: [PATCH] Better cross_entropy test (#483) --- tests/test_cross_entropy_fairinternal.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_cross_entropy_fairinternal.py b/tests/test_cross_entropy_fairinternal.py index 4308fa385a..8ee7557dfa 100644 --- a/tests/test_cross_entropy_fairinternal.py +++ b/tests/test_cross_entropy_fairinternal.py @@ -76,10 +76,16 @@ def test_cross_entropy_distribution(dtype_str, B_M_K, student_scale_bias) -> Non @pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) @pytest.mark.parametrize("ignore_index", [-100, 5]) @pytest.mark.parametrize("bw_inplace", [False, True]) +@pytest.mark.parametrize("reference_fp32", [False, True]) def test_softmax_cross_entropy( - dtype_str, B_M_K, label_smoothing, ignore_index, bw_inplace + dtype_str, B_M_K, label_smoothing, ignore_index, bw_inplace, reference_fp32 ) -> None: - if dtype_str == "f16" and label_smoothing == 0.1 and B_M_K == (12, 13, 131072): + if ( + (not reference_fp32) + and dtype_str == "f16" + and label_smoothing == 0.1 + and B_M_K == (12, 13, 131072) + ): pytest.skip( "When matrices are large and FP16 is used with label smoothing," "original Pytorch cross-entropy (the reference) returns inf" @@ -106,13 +112,14 @@ def test_softmax_cross_entropy( ignore_index=ignore_index, bw_inplace=bw_inplace, ) + ref_loss = F.cross_entropy( - input_ref, + input_ref.float() if reference_fp32 else input_ref, labels, label_smoothing=label_smoothing, ignore_index=ignore_index, reduction="none", - ) + ).to(dtype) assert_allclose(loss.float(), ref_loss.float(), "loss", atol=atol, rtol=rtol)