From 1edb427da2e8aa1a77a55b42b156f13c9ae583d6 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sun, 3 Jul 2022 17:08:23 +0100 Subject: [PATCH] workarounds for XLA bugs + pecularities --- src/Tensor.idr | 16 ++++++++++++---- test/Unit/TestTensor.idr | 6 ------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/Tensor.idr b/src/Tensor.idr index 72accab32..344cac37b 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -1135,11 +1135,19 @@ uniform : (bound, bound' : Tensor shape F64) -> Rand (Tensor shape F64) uniform (MkTensor key) bound bound' = - let MkTensor minval = min bound bound' - MkTensor maxval = max bound bound' + let minval@(MkTensor minvalExpr) = min bound bound' + maxval@(MkTensor maxvalExpr) = max bound bound' in ST $ \(MkTensor initialState) => - let valueState = UniformFloatingPoint key initialState minval maxval shape - in Id (MkTensor $ GetTupleElement 1 valueState, MkTensor $ GetTupleElement 0 valueState) + let valueState = UniformFloatingPoint key initialState minvalExpr maxvalExpr shape + value = MkTensor $ GetTupleElement 0 valueState + -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 + -- samples between -inf and 0 should be at -inf, but XLA produces nan + -- similarly, samples in (inf, inf) should be at inf and respectively for -inf + inf = broadcast inf + value = select (minval == - inf && maxval == fill 0) (- inf) value + value = select (minval == inf && maxval == inf) inf value + value = select (minval == - inf && maxval == - inf) (- inf) value + in Id (MkTensor $ GetTupleElement 1 valueState, value) ||| Generate independent and identically distributed (IID) samples from the standard normal ||| distribution. diff --git a/test/Unit/TestTensor.idr b/test/Unit/TestTensor.idr index a811fd661..36efb27db 100644 --- a/test/Unit/TestTensor.idr +++ b/test/Unit/TestTensor.idr @@ -1104,12 +1104,6 @@ uniformForNonFiniteBounds = property $ do seed = fromLiteral seed samples = evalState seed (uniform key (broadcast bound) (broadcast bound')) - -- XLA is inconsistent in how it handles (-inf, 0) and (0, inf) - -- XLA says (-inf, -inf) and (inf, inf) samples are nan. That could be argued since we can't - -- assert inf and inf are ordered and therefore that the bounds are indeed the min and max. - -- That said, that seems like weak argument since it seems more reasonable to interpret them as - -- ordered since there's no way for the user to specify anything else. I'm going to call this - -- a bug in XLA. samples ===# fromLiteral [-inf, inf, nan, -inf, nan, nan, inf, nan, nan] covering