Better cross_entropy test (facebookresearch#483)
artkorenev authored Feb 24, 2023
1 parent 7c25e76 commit a053879
Showing 1 changed file with 11 additions and 4 deletions.
tests/test_cross_entropy_fairinternal.py (15 changes: 11 additions & 4 deletions)
@@ -76,10 +76,16 @@ def test_cross_entropy_distribution(dtype_str, B_M_K, student_scale_bias) -> None
 @pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 @pytest.mark.parametrize("ignore_index", [-100, 5])
 @pytest.mark.parametrize("bw_inplace", [False, True])
+@pytest.mark.parametrize("reference_fp32", [False, True])
 def test_softmax_cross_entropy(
-    dtype_str, B_M_K, label_smoothing, ignore_index, bw_inplace
+    dtype_str, B_M_K, label_smoothing, ignore_index, bw_inplace, reference_fp32
 ) -> None:
-    if dtype_str == "f16" and label_smoothing == 0.1 and B_M_K == (12, 13, 131072):
+    if (
+        (not reference_fp32)
+        and dtype_str == "f16"
+        and label_smoothing == 0.1
+        and B_M_K == (12, 13, 131072)
+    ):
         pytest.skip(
             "When matrices are large and FP16 is used with label smoothing,"
             "original Pytorch cross-entropy (the reference) returns inf"
@@ -106,13 +112,14 @@ def test_softmax_cross_entropy(
         ignore_index=ignore_index,
         bw_inplace=bw_inplace,
     )
+
     ref_loss = F.cross_entropy(
-        input_ref,
+        input_ref.float() if reference_fp32 else input_ref,
         labels,
         label_smoothing=label_smoothing,
         ignore_index=ignore_index,
         reduction="none",
-    )
+    ).to(dtype)
 
     assert_allclose(loss.float(), ref_loss.float(), "loss", atol=atol, rtol=rtol)
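For context (not part of the commit): the new reference_fp32 parameter lets the test compare against a reference loss computed in FP32 and cast back to the working dtype, so the FP16-overflow skip is only needed when the reference itself runs in FP16. Below is a minimal sketch of the two reference paths; the shape comes from the skip condition, but the device, input generation, and variable names are illustrative assumptions, not the test's own setup.

import torch
import torch.nn.functional as F

B, M, K = 12, 13, 131072  # the shape named in the skip condition
device = "cuda"  # assumption: a GPU is available, as in the test suite

logits = torch.randn(B * M, K, dtype=torch.float16, device=device)
labels = torch.randint(0, K, (B * M,), device=device)

# FP16 reference: with a large class dimension and label smoothing,
# PyTorch's cross-entropy can overflow to inf (per the skip message above).
ref_fp16 = F.cross_entropy(logits, labels, label_smoothing=0.1, reduction="none")

# FP32 reference (the reference_fp32=True branch): upcast the logits,
# compute the loss in full precision, then cast the result back so the
# comparison happens at the same precision as the kernel under test.
ref_fp32 = F.cross_entropy(
    logits.float(), labels, label_smoothing=0.1, reduction="none"
).to(torch.float16)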
