From b693bb84200d70aa736f2491ff83509fd1b1b6fb Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Thu, 29 Jul 2021 18:35:53 -0700 Subject: [PATCH] Improve error messages in Keras activations / constraints / metrics. PiperOrigin-RevId: 387713769 --- keras/activations.py | 5 +-- keras/backend_config.py | 13 ++++-- keras/constraints.py | 8 ++-- keras/metrics.py | 59 +++++++++++++++++--------- keras/metrics_confusion_matrix_test.py | 18 +++++--- keras/metrics_test.py | 2 +- 6 files changed, 67 insertions(+), 38 deletions(-) diff --git a/keras/activations.py b/keras/activations.py index 4a1de98b8636..0bc3d2d35477 100644 --- a/keras/activations.py +++ b/keras/activations.py @@ -85,7 +85,7 @@ def softmax(x, axis=-1): output = e / s else: raise ValueError('Cannot apply softmax to a tensor that is 1D. ' - 'Received input: %s' % (x,)) + f'Received input: {x}') # Cache the logits to use for crossentropy loss. output._keras_logits = x # pylint: disable=protected-access @@ -600,5 +600,4 @@ def get(identifier): return identifier else: raise TypeError( - 'Could not interpret activation function identifier: {}'.format( - identifier)) + f'Could not interpret activation function identifier: {identifier}') diff --git a/keras/backend_config.py b/keras/backend_config.py index f5ce3db29d54..a76e40b38086 100644 --- a/keras/backend_config.py +++ b/keras/backend_config.py @@ -103,8 +103,10 @@ def set_floatx(value): ValueError: In case of invalid value. """ global _FLOATX - if value not in {'float16', 'float32', 'float64'}: - raise ValueError('Unknown floatx type: ' + str(value)) + accepted_dtypes = {'float16', 'float32', 'float64'} + if value not in accepted_dtypes: + raise ValueError( + f'Unknown `floatx` value: {value}. Expected one of {accepted_dtypes}') _FLOATX = str(value) @@ -142,6 +144,9 @@ def set_image_data_format(data_format): ValueError: In case of invalid `data_format` value. """ global _IMAGE_DATA_FORMAT - if data_format not in {'channels_last', 'channels_first'}: - raise ValueError('Unknown data_format: ' + str(data_format)) + accepted_formats = {'channels_last', 'channels_first'} + if data_format not in accepted_formats: + raise ValueError( + f'Unknown `data_format`: {data_format}. ' + f'Expected one of {accepted_formats}') _IMAGE_DATA_FORMAT = str(data_format) diff --git a/keras/constraints.py b/keras/constraints.py index 70e54efac52d..c3302ab195c5 100644 --- a/keras/constraints.py +++ b/keras/constraints.py @@ -257,7 +257,8 @@ def __call__(self, w): w_shape = w.shape if w_shape.rank is None or w_shape.rank != 4: raise ValueError( - 'The weight tensor must be of rank 4, but is of shape: %s' % w_shape) + 'The weight tensor must have rank 4. ' + f'Received weight tensor with shape: {w_shape}') height, width, channels, kernels = w_shape w = backend.reshape(w, (height, width, channels * kernels)) @@ -332,6 +333,7 @@ def deserialize(config, custom_objects=None): @keras_export('keras.constraints.get') def get(identifier): + """Retrieves a Keras constraint function.""" if identifier is None: return None if isinstance(identifier, dict): @@ -342,5 +344,5 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret constraint identifier: ' + - str(identifier)) + raise ValueError( + f'Could not interpret constraint function identifier: {identifier}') diff --git a/keras/metrics.py b/keras/metrics.py index 7e7b0c44d8ae..feb4144f08a2 100644 --- a/keras/metrics.py +++ b/keras/metrics.py @@ -394,9 +394,9 @@ def update_state(self, values, sample_weight=None): values = tf.cast(values, self._dtype) except (ValueError, TypeError): msg = ('The output of a metric function can only be a single Tensor. ' - 'Got: %s' % (values,)) + f'Received: {values}. ') if isinstance(values, dict): - msg += ('. To return a dict of values, implement a custom Metric ' + msg += ('To return a dict of values, implement a custom Metric ' 'subclass.') raise RuntimeError(msg) if sample_weight is not None: @@ -438,7 +438,8 @@ def update_state(self, values, sample_weight=None): num_values = tf.reduce_sum(sample_weight) else: raise NotImplementedError( - 'reduction [%s] not implemented' % self.reduction) + f'Reduction "{self.reduction}" not implemented. Expected ' + '"sum", "weighted_mean", or "sum_over_batch_size".') with tf.control_dependencies([update_total_op]): return self.count.assign_add(num_values) @@ -453,7 +454,8 @@ def result(self): return tf.math.divide_no_nan(self.total, self.count) else: raise NotImplementedError( - 'reduction [%s] not implemented' % self.reduction) + f'Reduction "{self.reduction}" not implemented. Expected ' + '"sum", "weighted_mean", or "sum_over_batch_size".') @keras_export('keras.metrics.Sum') @@ -1530,7 +1532,9 @@ def __init__(self, dtype=None): super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) if num_thresholds <= 0: - raise ValueError('`num_thresholds` must be > 0.') + raise ValueError( + 'Argument `num_thresholds` must be an integer > 0. ' + f'Received: num_thresholds={num_thresholds}') self.value = value self.class_id = class_id self.true_positives = self.add_weight( @@ -1689,7 +1693,9 @@ def __init__(self, name=None, dtype=None): if specificity < 0 or specificity > 1: - raise ValueError('`specificity` must be in the range [0, 1].') + raise ValueError( + 'Argument `specificity` must be in the range [0, 1]. ' + f'Received: specificity={specificity}') self.specificity = specificity self.num_thresholds = num_thresholds super(SensitivityAtSpecificity, self).__init__( @@ -1781,7 +1787,9 @@ def __init__(self, name=None, dtype=None): if sensitivity < 0 or sensitivity > 1: - raise ValueError('`sensitivity` must be in the range [0, 1].') + raise ValueError( + 'Argument `sensitivity` must be in the range [0, 1]. ' + f'Received: sensitivity={sensitivity}') self.sensitivity = sensitivity self.num_thresholds = num_thresholds super(SpecificityAtSensitivity, self).__init__( @@ -1865,7 +1873,9 @@ def __init__(self, name=None, dtype=None): if recall < 0 or recall > 1: - raise ValueError('`recall` must be in the range [0, 1].') + raise ValueError( + 'Argument `recall` must be in the range [0, 1]. ' + f'Received: recall={recall}') self.recall = recall self.num_thresholds = num_thresholds super(PrecisionAtRecall, self).__init__( @@ -1949,7 +1959,9 @@ def __init__(self, name=None, dtype=None): if precision < 0 or precision > 1: - raise ValueError('`precision` must be in the range [0, 1].') + raise ValueError( + 'Argument `precision` must be in the range [0, 1]. ' + f'Received: precision={precision}') self.precision = precision self.num_thresholds = num_thresholds super(RecallAtPrecision, self).__init__( @@ -2105,15 +2117,16 @@ def __init__(self, # Validate configurations. if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( metrics_utils.AUCCurve): - raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( - curve, list(metrics_utils.AUCCurve))) + raise ValueError( + f'Invalid `curve` argument value "{curve}". ' + f'Expected one of: {list(metrics_utils.AUCCurve)}') if isinstance( summation_method, metrics_utils.AUCSummationMethod) and summation_method not in list( metrics_utils.AUCSummationMethod): raise ValueError( - 'Invalid summation method: "{}". Valid options are: "{}"'.format( - summation_method, list(metrics_utils.AUCSummationMethod))) + f'Invalid `summation_method` argument value "{summation_method}". ' + f'Expected one of: {list(metrics_utils.AUCSummationMethod)}') # Update properties. self._init_from_thresholds = thresholds is not None @@ -2126,7 +2139,8 @@ def __init__(self, np.array([0.0] + thresholds + [1.0]))) else: if num_thresholds <= 1: - raise ValueError('`num_thresholds` must be > 1.') + raise ValueError('Argument `num_thresholds` must be an integer > 1. ' + f'Received: num_thresholds={num_thresholds}') # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in # (0, 1). @@ -2188,8 +2202,10 @@ def _build(self, shape): """Initialize TP, FP, TN, and FN tensors, given the shape of the data.""" if self.multi_label: if shape.ndims != 2: - raise ValueError('`y_true` must have rank=2 when `multi_label` is ' - 'True. Found rank %s.' % shape.ndims) + raise ValueError( + '`y_true` must have rank 2 when `multi_label=True`. ' + f'Found rank {shape.ndims}. ' + f'Full shape received for `y_true`: {shape}') self._num_labels = shape[1] variable_shape = tf.TensorShape( [tf.compat.v1.Dimension(self.num_thresholds), self._num_labels]) @@ -3136,9 +3152,10 @@ def update_state(self, values, sample_weight=None): if not self._built: self._build(values.shape) elif values.shape != self._shape: - raise ValueError('MeanTensor input values must always have the same ' - 'shape. Expected shape (set during the first call): {}. ' - 'Got: {}'.format(self._shape, values.shape)) + raise ValueError( + 'MeanTensor input values must always have the same ' + f'shape. Expected shape (set during the first call): {self._shape}. ' + f'Got: {values.shape}.') num_values = tf.ones_like(values) if sample_weight is not None: @@ -3168,7 +3185,7 @@ def update_state(self, values, sample_weight=None): def result(self): if not self._built: raise ValueError( - 'MeanTensor does not have any result yet. Please call the MeanTensor ' + 'MeanTensor does not have any value yet. Please call the MeanTensor ' 'instance or use `.update_state(value)` before retrieving the result.' ) return tf.math.divide_no_nan(self.total, self.count) @@ -3715,7 +3732,7 @@ def get(identifier): return identifier else: raise ValueError( - 'Could not interpret metric function identifier: {}'.format(identifier)) + f'Could not interpret metric identifier: {identifier}') def is_built_in(cls): diff --git a/keras/metrics_confusion_matrix_test.py b/keras/metrics_confusion_matrix_test.py index 19d48112cbc3..26c4ff5a84d3 100644 --- a/keras/metrics_confusion_matrix_test.py +++ b/keras/metrics_confusion_matrix_test.py @@ -829,7 +829,8 @@ def test_invalid_specificity(self): metrics.SensitivityAtSpecificity(-1) def test_invalid_num_thresholds(self): - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 0'): metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1) @@ -941,7 +942,8 @@ def test_invalid_sensitivity(self): metrics.SpecificityAtSensitivity(-1) def test_invalid_num_thresholds(self): - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 0'): metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1) @@ -1054,7 +1056,8 @@ def test_invalid_sensitivity(self): metrics.PrecisionAtRecall(-1) def test_invalid_num_thresholds(self): - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 0'): metrics.PrecisionAtRecall(0.4, num_thresholds=-1) @@ -1186,7 +1189,8 @@ def test_invalid_sensitivity(self): metrics.RecallAtPrecision(-1) def test_invalid_num_thresholds(self): - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 0.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 0'): metrics.RecallAtPrecision(0.4, num_thresholds=-1) @@ -1453,10 +1457,12 @@ def test_weighted_pr_interpolation(self): self.assertAllClose(self.evaluate(result), expected_result, 1e-3) def test_invalid_num_thresholds(self): - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 1.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 1'): metrics.AUC(num_thresholds=-1) - with self.assertRaisesRegex(ValueError, '`num_thresholds` must be > 1.'): + with self.assertRaisesRegex( + ValueError, 'Argument `num_thresholds` must be an integer > 1.'): metrics.AUC(num_thresholds=1) def test_invalid_curve(self): diff --git a/keras/metrics_test.py b/keras/metrics_test.py index e159428b64f9..26e0da725b23 100644 --- a/keras/metrics_test.py +++ b/keras/metrics_test.py @@ -1379,7 +1379,7 @@ def test_config(self): self.assertEqual(m.dtype, tf.float32) self.assertEmpty(m.variables) - with self.assertRaisesRegex(ValueError, 'does not have any result yet'): + with self.assertRaisesRegex(ValueError, 'does not have any value yet'): m.result() self.evaluate(m([[3], [5], [3]]))