Add mixed precision support (#46)
* Add .python-version to .gitignore for pyenv environments

* Cast ClippedBinaryCrossentropy parameters to the same type as the loss for mixed precision support

* Set activation layer of LearntNorms to be of float32 dtype policy if global policy is mixed precision

* Add ClippedBinaryCrossentropy test for all float dtypes

* Update CI config to install correct version of Tensorflow

* Remove activation in last Dense layer
julienperichon authored Jul 29, 2022
1 parent 149dd31 commit 0f176f8
Showing 8 changed files with 206 additions and 804 deletions.
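All of the changes below assume mixed precision is turned on through the Keras global dtype policy. A minimal sketch of how that policy might be enabled, mirroring the experimental API this commit imports (TensorFlow 2.x; the exact version is pinned by the CI config):

import tensorflow as tf

# Enable mixed precision globally: layers compute in float16 while keeping
# float32 variables. This uses the same experimental namespace that the
# diffs below import (tensorflow.keras.mixed_precision.experimental).
policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)

# Any model or loss built after this point inherits mixed_float16, which is
# why ClippedBinaryCrossentropy and LearntNorms need the dtype-aware code
# introduced in this commit.
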
2 changes: 1 addition & 1 deletion .github/workflows/install.yml
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7]
+        python-version: [ 3.6, 3.7 ]

     steps:
       - name: Set up Python ${{ matrix.python-version }}
1 change: 1 addition & 0 deletions .gitignore
@@ -9,3 +9,4 @@ siamese_nets_classifier
 .coverage
 *egg-info
 logs
+.python-version
9 changes: 6 additions & 3 deletions keras_fsl/losses/gram_matrix_losses.py
@@ -12,8 +12,8 @@
 """
 import tensorflow as tf
 import tensorflow.keras.backend as K
-from tensorflow.keras.losses import Loss
 import tensorflow_probability as tfp
+from tensorflow.keras.losses import Loss


 class MeanScoreClassificationLoss(Loss):
@@ -70,8 +70,11 @@ def __init__(self, lower=0.0, upper=1.0, **kwargs):

     def call(self, y_true, y_pred):
         loss = super().call(y_true, y_pred)
-        clip_mask = tf.math.logical_and(-tf.math.log(1 - self.lower) < loss, loss < -tf.math.log(1 - self.upper))
-        return tf.cast(clip_mask, dtype=y_pred.dtype) * loss
+        clip_mask = tf.math.logical_and(
+            -tf.math.log(1 - tf.cast(self.lower, dtype=loss.dtype)) < loss,
+            loss < -tf.math.log(1 - tf.cast(self.upper, dtype=loss.dtype)),
+        )
+        return tf.cast(clip_mask, dtype=loss.dtype) * loss


 # TODO: use reduction kwarg of loss when it becomes possible to give custom reduction to includes all other reductions below in
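For context on the cast above: under a mixed_float16 policy the model output, and therefore the loss tensor, is float16, while self.lower and self.upper are plain Python floats that become float32 tensors once they enter a TensorFlow op; comparing a float32 bound against a float16 loss then fails, so the bounds are cast to loss.dtype first. A minimal standalone sketch of the same mechanics (tensor values are illustrative, not taken from the repository):

import tensorflow as tf

# Per-element binary cross-entropy values as they would appear under mixed_float16.
loss = tf.constant([0.02, 0.7, 4.0], dtype=tf.float16)
lower, upper = 0.05, 0.95  # Python floats, i.e. float32 once used in a tf op

# Casting the bounds to loss.dtype keeps both comparisons in a single dtype;
# without the cast, TensorFlow refuses to compare float32 against float16.
clip_mask = tf.math.logical_and(
    -tf.math.log(1 - tf.cast(lower, dtype=loss.dtype)) < loss,
    loss < -tf.math.log(1 - tf.cast(upper, dtype=loss.dtype)),
)
clipped_loss = tf.cast(clip_mask, dtype=loss.dtype) * loss  # zero outside the clip window
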
9 changes: 9 additions & 0 deletions keras_fsl/losses/tests/gram_matrix_losses_test.py
@@ -140,6 +140,15 @@ def test_clipped_loss_should_equal_literal_calculation(self, y_true, adjacency_m
         )
         np.testing.assert_almost_equal(tf_loss, np_loss, decimal=5)

+    @staticmethod
+    @pytest.mark.parametrize("dtype_policy", (tf.float16, tf.bfloat16, tf.float32, tf.float64))
+    def test_clipped_loss_computes_in_all_float_dtypes(dtype_policy, y_true, y_pred):
+        y_true_tensor = tf.convert_to_tensor(y_true)
+        y_pred_tensor = tf.convert_to_tensor(y_pred)
+        ClippedBinaryCrossentropy(lower=0.05, upper=0.95)(
+            tf.cast(y_true_tensor, dtype=dtype_policy), tf.cast(y_pred_tensor, dtype=dtype_policy)
+        )
+
     def test_max_loss_should_equal_literal_calculation(self, y_true, adjacency_matrix, y_pred):
         np_loss = np.max(-(adjacency_matrix * np.log(y_pred) + (1 - adjacency_matrix) * np.log(1 - y_pred)))
         tf_loss = MaxBinaryCrossentropy()(
10 changes: 9 additions & 1 deletion keras_fsl/models/head_models/learnt_norms.py
@@ -1,4 +1,5 @@
 import numpy as np
+import tensorflow as tf
 from tensorflow.keras import activations
 from tensorflow.keras.layers import (
     Concatenate,
@@ -8,7 +9,9 @@
     Input,
     Reshape,
 )
+from tensorflow.keras.mixed_precision.experimental import global_policy
 from tensorflow.keras.models import Model
+from tensorflow.python.keras.layers import Activation


 def LearntNorms(input_shape, use_bias=True, activation="sigmoid"):
@@ -31,6 +34,11 @@ def LearntNorms(input_shape, use_bias=True, activation="sigmoid"):
     )
     output = Conv2D(filters=1, kernel_size=(1, 1), activation="linear", name="norms_average", use_bias=use_bias)(output)
     output = Flatten()(output)
+    output = Dense(1, name="raw_output", use_bias=use_bias)(output)

-    output = Dense(1, activation=activations.get(activation), name="output", use_bias=use_bias)(output)
+    global_dtype_policy = global_policy().name
+    if global_dtype_policy in ["mixed_float16", "mixed_bfloat16"]:
+        output = Activation(activations.get(activation), dtype=tf.float32, name="predictions")(output)
+    else:
+        output = Activation(activations.get(activation), name="predictions")(output)
     return Model(inputs=inputs, outputs=output)
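
Keeping the final Activation in float32 follows the usual mixed-precision guidance: intermediate layers can safely compute in float16, but the sigmoid output is kept in float32 to avoid precision loss in the predictions. A rough usage sketch (the import path is inferred from the file layout and may differ):

import tensorflow as tf
from keras_fsl.models.head_models.learnt_norms import LearntNorms  # path inferred from the repo layout

policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
tf.keras.mixed_precision.experimental.set_policy(policy)

learnt_norms = LearntNorms(input_shape=(10,))

# Internal layers follow the global mixed_float16 policy, while the final
# Activation layer added above is pinned to float32, so predictions come
# back as float32 tensors.
print(learnt_norms.output.dtype)  # expected: float32
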
18 changes: 18 additions & 0 deletions keras_fsl/models/head_models/tests/learnt_norms_test.py
@@ -27,6 +27,24 @@ def test_should_fit(self, input_shape):

         learnt_norms.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

+    @parameterized.named_parameters(
+        ("mixed_float16", "mixed_float16", "float32"),
+        ("mixed_bfloat16", "mixed_bfloat16", "float32"),
+        ("float32", "float32", "float32"),
+        ("float64", "float64", "float64"),
+    )
+    def test_last_activation_fp32_in_mixed_precision(self, mixed_precision_policy, expected_last_layer_dtype_policy):
+        policy = tf.keras.mixed_precision.experimental.Policy(mixed_precision_policy)
+        tf.keras.mixed_precision.experimental.set_policy(policy)
+        learnt_norms = LearntNorms(input_shape=(10,))
+
+        # Check dtype policy of internal non-input layers
+        for layer in learnt_norms.layers[2:-1]:
+            assert layer._dtype_policy.name == mixed_precision_policy
+
+        # Check dtype policy of last layer always at least FP32
+        assert learnt_norms.layers[-1]._dtype_policy.name == expected_last_layer_dtype_policy
+

 if __name__ == "__main__":
     tf.test.main()