From ccd25a1a152c08e6b2d4692bdb640ec6d28d7d9f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 5 Feb 2025 15:35:58 -0800 Subject: [PATCH] Use assertAllClose to compare float32 arrays PiperOrigin-RevId: 723676676 --- praxis/sample_decode_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/praxis/sample_decode_test.py b/praxis/sample_decode_test.py index 7a554aa6..a927dddf 100644 --- a/praxis/sample_decode_test.py +++ b/praxis/sample_decode_test.py @@ -506,10 +506,10 @@ def decode_fn(model, input_ids, input_paddings): # batch size is 1. self.assertEqual(1, top_candidate_logprobs.shape[0]) self.assertEqual(1, top_candidate_ids.shape[0]) - self.assertArraysEqual( + self.assertAllClose( logprobs, top_candidate_logprobs[0, :, :, :num_per_token_logprobs] ) - self.assertArraysEqual( + self.assertAllClose( ids, top_candidate_ids[0, :, :, :num_per_token_logprobs] )