diff --git a/keras_fsl/losses/yolo_loss.py b/keras_fsl/losses/yolo_loss.py
new file mode 100644
index 0000000..b311653
--- /dev/null
+++ b/keras_fsl/losses/yolo_loss.py
@@ -0,0 +1,47 @@
+import tensorflow as tf
+
+
+def yolo_loss(anchors, threshold):
+    """
+    Build a YOLO loss function for the given anchors.
+
+    Args:
+        anchors (pandas.DataFrame): dataframe of the anchors with width and height columns.
+        threshold (float): IoU threshold; the confidence loss is computed only for predicted boxes whose IoU with the
+            true boxes is below this value.
+    """
+
+    def _yolo_loss(y_true, y_pred):
+        """
+        y_true and y_pred are (batch_size, number of boxes, 4 (+ 1) + number of classes (+ anchor_id for y_pred)).
+        The number of boxes is determined by the network architecture because in single-shot detection one can only
+        predict grid_width x grid_height boxes per anchor.
+        """
+        # 1. Find matching anchors: the anchor with the best IoU is chosen for predicting each true box
+        y_true_broadcast = tf.expand_dims(y_true, axis=2)
+        anchors_tensor = tf.broadcast_to(anchors[["height", "width"]].values, [1, 1, len(anchors), 2])
+
+        height_width_min = tf.minimum(y_true_broadcast[..., 2:4], anchors_tensor)
+        height_width_max = tf.maximum(y_true_broadcast[..., 2:4], anchors_tensor)
+        intersection = tf.reduce_prod(height_width_min, axis=-1)
+        true_box_area = tf.reduce_prod(y_true_broadcast[..., 2:4], axis=-1)
+        anchor_boxes_area = tf.reduce_prod(anchors_tensor, axis=-1)
+        union = true_box_area + anchor_boxes_area - intersection
+        iou = intersection / union
+        best_anchor = tf.math.argmax(iou, axis=-1)
+
+        # 2. Find grid cell: for each selected anchor, select the prediction coming from the cell which contains
+        # the true box center
+        batch_size, boxes, _ = tf.unstack(tf.shape(y_true))
+        for image in range(batch_size):
+            for box in range(boxes):
+                true_box_info = y_true[image, box]
+                selected_anchor = tf.cast(best_anchor[image, box], y_pred.dtype)
+                prediction_for_anchor = tf.boolean_mask(y_pred[image], y_pred[image, :, -1] == selected_anchor, axis=0)
+                grid_size = prediction_for_anchor
+
+        # 3. For confidence loss: for each selected anchor, compute confidence loss for boxes with IoU < threshold
+        non_empty_boxes_mask = tf.cast(tf.math.reduce_prod(y_true[..., 2:4], axis=-1) > 0, tf.bool)
+        pass
+
+    return _yolo_loss
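
A quick usage sketch (illustrative only, not part of the diff): yolo_loss is a factory returning a Keras-compatible closure, so once _yolo_loss computes and returns an actual loss value it can be wired in through compile. The anchor values below are taken from the ANCHORS table introduced later in this diff; the 0.5 threshold and the detector model are assumptions.

import pandas as pd
from keras_fsl.losses.yolo_loss import yolo_loss

# Anchor shapes are fractions of the image dimensions, matching the DataFrame expected by yolo_loss
anchors = pd.DataFrame({"width": [116 / 416, 30 / 416, 10 / 416], "height": [90 / 416, 61 / 416, 13 / 416]})
loss = yolo_loss(anchors, threshold=0.5)
# detector.compile(optimizer="adam", loss=loss)  # detector: a single-shot model such as FeaturePyramidNet below
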
diff --git a/keras_fsl/models/__init__.py b/keras_fsl/models/__init__.py
index bba21f2..65b244d 100644
--- a/keras_fsl/models/__init__.py
+++ b/keras_fsl/models/__init__.py
@@ -1,3 +1,4 @@
+from .feature_pyramid_net import FeaturePyramidNet
 from .siamese_nets import SiameseNets
 
-__all__ = ["SiameseNets"]
+__all__ = ["FeaturePyramidNet", "SiameseNets"]
diff --git a/keras_fsl/models/activations/__init__.py b/keras_fsl/models/activations/__init__.py
new file mode 100644
index 0000000..c2199f8
--- /dev/null
+++ b/keras_fsl/models/activations/__init__.py
@@ -0,0 +1,7 @@
+from .yolo_box import YoloBox
+from .yolo_coordinates import YoloCoordinates
+
+__all__ = [
+    "YoloBox",
+    "YoloCoordinates",
+]
diff --git a/keras_fsl/models/activations/yolo_box.py b/keras_fsl/models/activations/yolo_box.py
new file mode 100644
index 0000000..0f34b5d
--- /dev/null
+++ b/keras_fsl/models/activations/yolo_box.py
@@ -0,0 +1,27 @@
+"""
+Activation function for mapping features into output box dimensions as in Yolo V3
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Activation, Lambda
+
+
+def YoloBox(anchor):
+    """
+    Activation function for the box dimension regression. Dimensions are relative to the image dimension, i.e. between
+    0 and 1.
+
+    Args:
+        anchor (Union[pandas.Series, collections.namedtuple]): with width and height keys. Note that given a tensor with
+            shape (batch_size, i, j, channels), i is related to height and j to width.
+    """
+    return Sequential(
+        [
+            Activation("exponential"),
+            Lambda(
+                lambda input_, anchor_=anchor: (
+                    input_ * tf.convert_to_tensor([anchor_.height, anchor_.width], dtype=tf.float32)
+                )
+            ),
+        ]
+    )
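
A minimal sketch (not part of the diff) of what the YoloBox activation computes: an exponential followed by a rescale by the anchor shape, so a zero feature predicts exactly the anchor box. The anchor value is borrowed from the ANCHORS table further down; the 13x13 grid size is an arbitrary assumption.

from collections import namedtuple

import tensorflow as tf
from keras_fsl.models.activations import YoloBox

Anchor = namedtuple("Anchor", ["width", "height"])
box_hw = YoloBox(Anchor(width=116 / 416, height=90 / 416))
features = tf.zeros([1, 13, 13, 2])  # (batch_size, grid_i, grid_j, 2) raw regression features
print(box_hw(features)[0, 0, 0].numpy())  # exp(0) * [anchor.height, anchor.width] ~ [0.216, 0.279]
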
diff --git a/keras_fsl/models/activations/yolo_coordinates.py b/keras_fsl/models/activations/yolo_coordinates.py
new file mode 100644
index 0000000..8325be9
--- /dev/null
+++ b/keras_fsl/models/activations/yolo_coordinates.py
@@ -0,0 +1,38 @@
+"""
+Activation function for mapping features into output coordinates as in Yolo V3
+"""
+import tensorflow as tf
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Activation, Lambda
+
+
+@tf.function
+def build_grid_coordinates(grid_shape):
+    """
+    Build a grid coordinate tensor with shape (*grid_shape, 2) where grid[i, j, 0] = i and grid[i, j, 1] = j
+    Args:
+        grid_shape (Union[tuple, list, tensorflow.TensorShape]): to be passed to tf.range
+
+    Returns:
+        (tensorflow.Tensor)
+    """
+    height, width = tf.meshgrid(tf.range(0, grid_shape[0]), tf.range(0, grid_shape[1]))
+    width = tf.transpose(width)
+    height = tf.transpose(height)
+    return tf.stack([height, width], -1)
+
+
+def YoloCoordinates():
+    """
+    Activation function for the box center coordinates regression. Coordinates are relative to the image dimension,
+    i.e. between 0 and 1.
+    """
+    return Sequential(
+        [
+            Activation("sigmoid"),
+            Lambda(
+                lambda input_: input_ + tf.cast(tf.expand_dims(build_grid_coordinates(tf.shape(input_)[1:3]), 0), input_.dtype)
+            ),
+            Lambda(lambda input_: input_ / tf.cast(tf.shape(input_)[1:3], input_.dtype)),
+        ]
+    )
diff --git a/keras_fsl/models/encoders/darknet.py b/keras_fsl/models/encoders/darknet.py
index 3924e32..93401a6 100644
--- a/keras_fsl/models/encoders/darknet.py
+++ b/keras_fsl/models/encoders/darknet.py
@@ -10,12 +10,9 @@ def conv_2d(*args, **kwargs):
     return Conv2D(*args, **kwargs, kernel_regularizer=l2(5e-4), padding="valid" if kwargs.get("strides") == (2, 2) else "same")
 
 
+@wraps(Conv2D)
 def conv_block(*args, **kwargs):
-    layer = Sequential()
-    layer.add(conv_2d(*args, **kwargs, use_bias=False))
-    layer.add(BatchNormalization())
-    layer.add(LeakyReLU(alpha=0.1))
-    return layer
+    return Sequential([conv_2d(*args, **kwargs, use_bias=False), BatchNormalization(), LeakyReLU(alpha=0.1)])
 
 
 def residual_block(input_shape, num_filters, num_blocks):
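
A minimal sketch (not part of the diff) of the coordinate activation on an assumed 2x3 grid: predictions are squashed with a sigmoid, offset by the cell indices from build_grid_coordinates, then divided by the grid shape so the centers land in [0, 1] relative to the image.

import tensorflow as tf
from keras_fsl.models.activations import YoloCoordinates
from keras_fsl.models.activations.yolo_coordinates import build_grid_coordinates

print(build_grid_coordinates((2, 3))[..., 0].numpy())  # row indices: [[0 0 0], [1 1 1]]
box_yx = YoloCoordinates()
centers = box_yx(tf.zeros([1, 2, 3, 2]))  # sigmoid(0) = 0.5, so cell (i, j) maps to ((i + 0.5) / 2, (j + 0.5) / 3)
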
diff --git a/keras_fsl/models/feature_pyramid_net.py b/keras_fsl/models/feature_pyramid_net.py
new file mode 100644
index 0000000..bf0bc83
--- /dev/null
+++ b/keras_fsl/models/feature_pyramid_net.py
@@ -0,0 +1,145 @@
+from functools import wraps
+
+import pandas as pd
+import tensorflow as tf
+from tensorflow.keras import Model, Sequential
+from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate, Reshape, Lambda
+
+from keras_fsl.models import encoders, activations
+
+ANCHORS = pd.DataFrame(
+    [
+        [0, 116 / 416, 90 / 416],
+        [0, 156 / 416, 198 / 416],
+        [0, 373 / 416, 326 / 416],
+        [1, 30 / 416, 61 / 416],
+        [1, 62 / 416, 45 / 416],
+        [1, 59 / 416, 119 / 416],
+        [2, 10 / 416, 13 / 416],
+        [2, 16 / 416, 30 / 416],
+        [2, 33 / 416, 23 / 416],
+    ],
+    columns=["scale", "width", "height"],
+)
+
+
+@wraps(Conv2D)
+def conv_block(*args, **kwargs):
+    return Sequential([Conv2D(*args, **kwargs, use_bias=False), BatchNormalization(), ReLU()])
+
+
+def bottleneck(filters, *args, **kwargs):
+    return Sequential(
+        [conv_block(filters // 4, (1, 1), padding="same"), conv_block(filters, (3, 3), padding="same")], *args, **kwargs
+    )
+
+
+def up_sampling_block(filters, *args, **kwargs):
+    return Sequential([conv_block(filters, (1, 1), padding="same"), UpSampling2D(2)], *args, **kwargs)
+
+
+def regression_block(activation, *args, **kwargs):
+    return Sequential([Conv2D(2, (1, 1)), getattr(activations, activation)(*args)], **kwargs)
+
+
+def FeaturePyramidNet(
+    backbone="MobileNet",
+    *args,
+    feature_maps=3,
+    objectness=True,
+    anchors=None,
+    classes=None,
+    weights=None,
+    coordinates_activation="YoloCoordinates",
+    box_activation="YoloBox",
+    **kwargs,
+):
+    """
+    Multi-scale feature extractor following the [Feature Pyramid Network for Object Detection](https://arxiv.org/pdf/1612.03144.pdf)
+    framework.
+
+    It analyses the given backbone architecture so as to extract the feature maps at relevant positions (the last position before
+    downsampling), then builds a model with as many feature maps (outputs) as requested, starting from the deepest one.
+
+    When classes is not None, it builds a single-shot detector from the features based on a given list of anchors. In this case, all
+    dimensions are relative to the image dimension: coordinates and box dimensions will be floats in [0, 1]. Hence anchors are defined
+    with floats for width and height. Each anchor should also specify which feature map it is based on: the current implementation
+    counts backward, with 0 meaning the smallest resolution, 1 the following one, etc. The output shape of the model is then a list of
+    boxes for each image, i.e. (batch_size, number of boxes, {coordinates, (objectness,) labels, anchor_id}).
+
+    Args:
+        backbone (Union[str, dict, tensorflow.keras.Model]): parameters of the feature extractor
+        feature_maps (int): number of feature maps to extract from the backbone.
+        objectness (bool): whether to add a score for object presence probability (similar to adding a background class, see Yolo for
+            instance).
+        anchors (pandas.DataFrame): containing scale, width and height columns. The scale column is used to select the corresponding
+            feature map: 0 for the smallest resolution, 1 for the next one, etc.
+        classes (pandas.Series): provide classes to build a single-shot detector from the anchors and the feature maps.
+        weights (Union[str, pathlib.Path]): path to the weights file to load with tensorflow.keras.load_weights
+        coordinates_activation (str): activation function to be used for the center coordinates regression
+        box_activation (str): activation function to be used for the box height and width regression
+    """
+    if not isinstance(backbone, Model):
+        if isinstance(backbone, str):
+            backbone = {"name": backbone, "init": {"include_top": False, "input_shape": (416, 416, 3)}}
+        backbone_name = backbone["name"]
+        backbone = getattr(encoders, backbone_name)(**backbone.get("init", {}))
+
+    output_shapes = (
+        pd.DataFrame(
+            [layer.input_shape[0] if isinstance(layer.input_shape, list) else layer.output_shape for layer in backbone.layers],
+            columns=["batch_size", "height", "width", "channels"],
+        )
+        .loc[lambda df: df.width.iloc[0] % df.width == 0]
+        .drop_duplicates(["width", "height"], keep="last")
+        .sort_index(ascending=False)
+    )
+
+    outputs = []
+    for output_shape in output_shapes.iloc[:feature_maps].itertuples():
+        input_ = backbone.layers[output_shape.Index].output
+        if outputs:
+            pyramid_input = up_sampling_block(output_shape.channels, name=f"up_sampling_{output_shape.channels}")(outputs[-1])
+            input_ = Concatenate()([input_, pyramid_input])
+        outputs += [bottleneck(output_shape.channels, name=f"bottleneck_{output_shape.channels}")(input_)]
+
+    if classes is not None:
+        if anchors is None:
+            anchors = ANCHORS.copy().round(3)
+        anchors = anchors.assign(
+            id=lambda df: "scale_" + df.scale.astype(str) + "_" + df.width.astype(str) + "x" + df.height.astype(str)
+        )
+        outputs = [
+            Reshape((-1, 4 + int(objectness) + len(classes)))(
+                Concatenate(axis=3, name=f"anchor_{anchor.id}_output")(
+                    [regression_block(coordinates_activation, name=f"{anchor.id}_box_yx")(outputs[anchor.scale])]
+                    + [regression_block(box_activation, anchor, name=f"{anchor.id}_box_hw")(outputs[anchor.scale])]
+                    + (
+                        [Conv2D(1, (1, 1), name=f"{anchor.id}_objectness", activation="sigmoid")(outputs[anchor.scale])]
+                        if objectness
+                        else []
+                    )
+                    + [
+                        Conv2D(1, (1, 1), name=f"{anchor.id}_{label}", activation="sigmoid")(outputs[anchor.scale])
+                        for label in classes
+                    ]
+                )
+            )
+            for anchor in anchors.itertuples()
+        ]
+        outputs = Concatenate(axis=1)(
+            [
+                Lambda(
+                    lambda output, index_=index: tf.concat(
+                        [output, tf.expand_dims(tf.ones(tf.shape(output)[:2], dtype=output.dtype) * index_, -1)], axis=-1
+                    )
+                )(outputs[index])
+                for index, anchor in anchors.iterrows()
+            ]
+        )
+
+    model = Model(backbone.input, outputs, *args, **kwargs)
+    if weights is not None:
+        model.load_weights(weights)
+
+    return model
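
An illustrative construction sketch (not part of the diff); the class list is an assumption, while the MobileNet backbone and the built-in ANCHORS are the defaults from the signature above.

import pandas as pd
from keras_fsl.models import FeaturePyramidNet

detector = FeaturePyramidNet("MobileNet", classes=pd.Series(["cat", "dog"]))
# Each output row is one predicted box: (y, x, h, w, objectness, one score per class, anchor_id)
detector.summary()
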
diff --git a/keras_fsl/models/layers/__init__.py b/keras_fsl/models/layers/__init__.py
new file mode 100644
index 0000000..423a131
--- /dev/null
+++ b/keras_fsl/models/layers/__init__.py
@@ -0,0 +1,3 @@
+from .classification import Classification
+from .gram_matrix import GramMatrix
+from .slicing import CenterSlicing2D
diff --git a/keras_fsl/utils/training.py b/keras_fsl/utils/training.py
index d012d33..496327f 100644
--- a/keras_fsl/utils/training.py
+++ b/keras_fsl/utils/training.py
@@ -2,6 +2,8 @@
 from functools import reduce, wraps
 from unittest.mock import patch
 
+import tensorflow as tf
+
 
 def patch_len(fit_generator):
     """