Skip to content

Commit

Permalink
Remove guards around b/143684500.
Browse files Browse the repository at this point in the history
Several kernels have previously been disabled for std::complex due
to lack of operator support on device.  This is no longer the case,
and these kernels can now be enabled.

PiperOrigin-RevId: 363924797
  • Loading branch information
tensorflower-gardener committed Mar 19, 2021
1 parent e95dc8e commit ca9dcdd
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 16 deletions.
6 changes: 1 addition & 5 deletions keras/optimizer_v2/adadelta_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@

from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import test_util
from keras import combinations
from keras.optimizer_v2 import adadelta

_DATA_TYPES = [tf.half, tf.float32, tf.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [tf.complex64, tf.complex128]
_DATA_TYPES = [tf.half, tf.float32, tf.float64, tf.complex64, tf.complex128]


class AdadeltaOptimizerTest(tf.test.TestCase, parameterized.TestCase):
Expand Down
6 changes: 1 addition & 5 deletions keras/optimizer_v2/adagrad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@

from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import test_util
from keras import combinations
from keras.optimizer_v2 import adagrad
from keras.optimizer_v2 import learning_rate_schedule

_DATA_TYPES = [tf.half, tf.float32, tf.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [tf.complex64, tf.complex128]
_DATA_TYPES = [tf.half, tf.float32, tf.float64, tf.complex64, tf.complex128]


def adagrad_update_numpy(param, accum, g_t, lr=0.001, epsilon=1e-7):
Expand Down
7 changes: 1 addition & 6 deletions keras/optimizer_v2/rmsprop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@
from keras.optimizer_v2 import learning_rate_schedule
from keras.optimizer_v2 import rmsprop

_DATA_TYPES = [tf.half, tf.float32, tf.float64]
# TODO(b/143684500): Eigen to support complex sqrt
if not test_util.IsBuiltWithNvcc():
_DATA_TYPES += [tf.complex64, tf.complex128]
_DATA_TYPES = [tf.half, tf.float32, tf.float64, tf.complex64, tf.complex128]

_TEST_PARAM_VALUES = [
# learning_rate, rho, momentum, epsilon, centered
Expand Down Expand Up @@ -346,8 +343,6 @@ def testMinimizeSparseResourceVariableCentered(self):
# TODO(tanzheny, omalleyt): Fix test in eager mode.
with tf.Graph().as_default():
for dtype in _DATA_TYPES:
if test_util.is_xla_enabled() and dtype.is_complex:
self.skipTest("b/143578550")
var0 = tf.Variable([[1.0, 2.0]], dtype=dtype)
x = tf.constant([[4.0], [5.0]], dtype=dtype)

Expand Down

0 comments on commit ca9dcdd

Please sign in to comment.