Skip to content

Commit

Permalink
Improve error messages in Keras activations / constraints / metrics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 387713769
  • Loading branch information
fchollet authored and tensorflower-gardener committed Jul 30, 2021
1 parent eef3ad0 commit b693bb8
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 38 deletions.
5 changes: 2 additions & 3 deletions keras/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def softmax(x, axis=-1):
output = e / s
else:
raise ValueError('Cannot apply softmax to a tensor that is 1D. '
'Received input: %s' % (x,))
f'Received input: {x}')

# Cache the logits to use for crossentropy loss.
output._keras_logits = x # pylint: disable=protected-access
Expand Down Expand Up @@ -600,5 +600,4 @@ def get(identifier):
return identifier
else:
raise TypeError(
'Could not interpret activation function identifier: {}'.format(
identifier))
f'Could not interpret activation function identifier: {identifier}')
13 changes: 9 additions & 4 deletions keras/backend_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def set_floatx(value):
ValueError: In case of invalid value.
"""
global _FLOATX
if value not in {'float16', 'float32', 'float64'}:
raise ValueError('Unknown floatx type: ' + str(value))
accepted_dtypes = {'float16', 'float32', 'float64'}
if value not in accepted_dtypes:
raise ValueError(
f'Unknown `floatx` value: {value}. Expected one of {accepted_dtypes}')
_FLOATX = str(value)


Expand Down Expand Up @@ -142,6 +144,9 @@ def set_image_data_format(data_format):
ValueError: In case of invalid `data_format` value.
"""
global _IMAGE_DATA_FORMAT
if data_format not in {'channels_last', 'channels_first'}:
raise ValueError('Unknown data_format: ' + str(data_format))
accepted_formats = {'channels_last', 'channels_first'}
if data_format not in accepted_formats:
raise ValueError(
f'Unknown `data_format`: {data_format}. '
f'Expected one of {accepted_formats}')
_IMAGE_DATA_FORMAT = str(data_format)
8 changes: 5 additions & 3 deletions keras/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def __call__(self, w):
w_shape = w.shape
if w_shape.rank is None or w_shape.rank != 4:
raise ValueError(
'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)
'The weight tensor must have rank 4. '
f'Received weight tensor with shape: {w_shape}')

height, width, channels, kernels = w_shape
w = backend.reshape(w, (height, width, channels * kernels))
Expand Down Expand Up @@ -332,6 +333,7 @@ def deserialize(config, custom_objects=None):

@keras_export('keras.constraints.get')
def get(identifier):
"""Retrieves a Keras constraint function."""
if identifier is None:
return None
if isinstance(identifier, dict):
Expand All @@ -342,5 +344,5 @@ def get(identifier):
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret constraint identifier: ' +
str(identifier))
raise ValueError(
f'Could not interpret constraint function identifier: {identifier}')
59 changes: 38 additions & 21 deletions keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,9 @@ def update_state(self, values, sample_weight=None):
values = tf.cast(values, self._dtype)
except (ValueError, TypeError):
msg = ('The output of a metric function can only be a single Tensor. '
'Got: %s' % (values,))
f'Received: {values}. ')
if isinstance(values, dict):
msg += ('. To return a dict of values, implement a custom Metric '
msg += ('To return a dict of values, implement a custom Metric '
'subclass.')
raise RuntimeError(msg)
if sample_weight is not None:
Expand Down Expand Up @@ -438,7 +438,8 @@ def update_state(self, values, sample_weight=None):
num_values = tf.reduce_sum(sample_weight)
else:
raise NotImplementedError(
'reduction [%s] not implemented' % self.reduction)
f'Reduction "{self.reduction}" not implemented. Expected '
'"sum", "weighted_mean", or "sum_over_batch_size".')

with tf.control_dependencies([update_total_op]):
return self.count.assign_add(num_values)
Expand All @@ -453,7 +454,8 @@ def result(self):
return tf.math.divide_no_nan(self.total, self.count)
else:
raise NotImplementedError(
'reduction [%s] not implemented' % self.reduction)
f'Reduction "{self.reduction}" not implemented. Expected '
'"sum", "weighted_mean", or "sum_over_batch_size".')


@keras_export('keras.metrics.Sum')
Expand Down Expand Up @@ -1530,7 +1532,9 @@ def __init__(self,
dtype=None):
super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
if num_thresholds <= 0:
raise ValueError('`num_thresholds` must be > 0.')
raise ValueError(
'Argument `num_thresholds` must be an integer > 0. '
f'Received: num_thresholds={num_thresholds}')
self.value = value
self.class_id = class_id
self.true_positives = self.add_weight(
Expand Down Expand Up @@ -1689,7 +1693,9 @@ def __init__(self,
name=None,
dtype=None):
if specificity < 0 or specificity > 1:
raise ValueError('`specificity` must be in the range [0, 1].')
raise ValueError(
'Argument `specificity` must be in the range [0, 1]. '
f'Received: specificity={specificity}')
self.specificity = specificity
self.num_thresholds = num_thresholds
super(SensitivityAtSpecificity, self).__init__(
Expand Down Expand Up @@ -1781,7 +1787,9 @@ def __init__(self,
name=None,
dtype=None):
if sensitivity < 0 or sensitivity > 1:
raise ValueError('`sensitivity` must be in the range [0, 1].')
raise ValueError(
'Argument `sensitivity` must be in the range [0, 1]. '
f'Received: sensitivity={sensitivity}')
self.sensitivity = sensitivity
self.num_thresholds = num_thresholds
super(SpecificityAtSensitivity, self).__init__(
Expand Down Expand Up @@ -1865,7 +1873,9 @@ def __init__(self,
name=None,
dtype=None):
if recall < 0 or recall > 1:
raise ValueError('`recall` must be in the range [0, 1].')
raise ValueError(
'Argument `recall` must be in the range [0, 1]. '
f'Received: recall={recall}')
self.recall = recall
self.num_thresholds = num_thresholds
super(PrecisionAtRecall, self).__init__(
Expand Down Expand Up @@ -1949,7 +1959,9 @@ def __init__(self,
name=None,
dtype=None):
if precision < 0 or precision > 1:
raise ValueError('`precision` must be in the range [0, 1].')
raise ValueError(
'Argument `precision` must be in the range [0, 1]. '
f'Received: precision={precision}')
self.precision = precision
self.num_thresholds = num_thresholds
super(RecallAtPrecision, self).__init__(
Expand Down Expand Up @@ -2105,15 +2117,16 @@ def __init__(self,
# Validate configurations.
if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
metrics_utils.AUCCurve):
raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
curve, list(metrics_utils.AUCCurve)))
raise ValueError(
f'Invalid `curve` argument value "{curve}". '
f'Expected one of: {list(metrics_utils.AUCCurve)}')
if isinstance(
summation_method,
metrics_utils.AUCSummationMethod) and summation_method not in list(
metrics_utils.AUCSummationMethod):
raise ValueError(
'Invalid summation method: "{}". Valid options are: "{}"'.format(
summation_method, list(metrics_utils.AUCSummationMethod)))
f'Invalid `summation_method` argument value "{summation_method}". '
f'Expected one of: {list(metrics_utils.AUCSummationMethod)}')

# Update properties.
self._init_from_thresholds = thresholds is not None
Expand All @@ -2126,7 +2139,8 @@ def __init__(self,
np.array([0.0] + thresholds + [1.0])))
else:
if num_thresholds <= 1:
raise ValueError('`num_thresholds` must be > 1.')
raise ValueError('Argument `num_thresholds` must be an integer > 1. '
f'Received: num_thresholds={num_thresholds}')

# Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
# (0, 1).
Expand Down Expand Up @@ -2188,8 +2202,10 @@ def _build(self, shape):
"""Initialize TP, FP, TN, and FN tensors, given the shape of the data."""
if self.multi_label:
if shape.ndims != 2:
raise ValueError('`y_true` must have rank=2 when `multi_label` is '
'True. Found rank %s.' % shape.ndims)
raise ValueError(
'`y_true` must have rank 2 when `multi_label=True`. '
f'Found rank {shape.ndims}. '
f'Full shape received for `y_true`: {shape}')
self._num_labels = shape[1]
variable_shape = tf.TensorShape(
[tf.compat.v1.Dimension(self.num_thresholds), self._num_labels])
Expand Down Expand Up @@ -3136,9 +3152,10 @@ def update_state(self, values, sample_weight=None):
if not self._built:
self._build(values.shape)
elif values.shape != self._shape:
raise ValueError('MeanTensor input values must always have the same '
'shape. Expected shape (set during the first call): {}. '
'Got: {}'.format(self._shape, values.shape))
raise ValueError(
'MeanTensor input values must always have the same '
f'shape. Expected shape (set during the first call): {self._shape}. '
f'Got: {values.shape}.')

num_values = tf.ones_like(values)
if sample_weight is not None:
Expand Down Expand Up @@ -3168,7 +3185,7 @@ def update_state(self, values, sample_weight=None):
def result(self):
if not self._built:
raise ValueError(
'MeanTensor does not have any result yet. Please call the MeanTensor '
'MeanTensor does not have any value yet. Please call the MeanTensor '
'instance or use `.update_state(value)` before retrieving the result.'
)
return tf.math.divide_no_nan(self.total, self.count)
Expand Down Expand Up @@ -3715,7 +3732,7 @@ def get(identifier):
return identifier
else:
raise ValueError(
'Could not interpret metric function identifier: {}'.format(identifier))
f'Could not interpret metric identifier: {identifier}')


def is_built_in(cls):
Expand Down
18 changes: 12 additions & 6 deletions keras/metrics_confusion_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ def test_invalid_specificity(self):
metrics.SensitivityAtSpecificity(-1)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 0'):
metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)


Expand Down Expand Up @@ -941,7 +942,8 @@ def test_invalid_sensitivity(self):
metrics.SpecificityAtSensitivity(-1)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 0'):
metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)


Expand Down Expand Up @@ -1054,7 +1056,8 @@ def test_invalid_sensitivity(self):
metrics.PrecisionAtRecall(-1)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 0'):
metrics.PrecisionAtRecall(0.4, num_thresholds=-1)


Expand Down Expand Up @@ -1186,7 +1189,8 @@ def test_invalid_sensitivity(self):
metrics.RecallAtPrecision(-1)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 0'):
metrics.RecallAtPrecision(0.4, num_thresholds=-1)


Expand Down Expand Up @@ -1453,10 +1457,12 @@ def test_weighted_pr_interpolation(self):
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 1.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 1'):
metrics.AUC(num_thresholds=-1)

with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 1.'):
with self.assertRaisesRegex(
ValueError, 'Argument `num_thresholds` must be an integer > 1.'):
metrics.AUC(num_thresholds=1)

def test_invalid_curve(self):
Expand Down
2 changes: 1 addition & 1 deletion keras/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ def test_config(self):
self.assertEqual(m.dtype, tf.float32)
self.assertEmpty(m.variables)

with self.assertRaisesRegex(ValueError, 'does not have any result yet'):
with self.assertRaisesRegex(ValueError, 'does not have any value yet'):
m.result()

self.evaluate(m([[3], [5], [3]]))
Expand Down

0 comments on commit b693bb8

Please sign in to comment.