Commit 717b1f5: WIP yolo loss

Clément Walter committed Mar 6, 2020
1 parent c0d570c
Showing 8 changed files with 167 additions and 21 deletions.
59 changes: 59 additions & 0 deletions keras_fsl/losses/yolo_loss.py
@@ -0,0 +1,59 @@
import tensorflow as tf


def yolo_loss(anchors, threshold):
    """
    Build the YOLO loss function for the given anchors (work in progress).

    Args:
        anchors (pandas.DataFrame): dataframe of the anchors with width and height columns.
        threshold (float): IoU threshold below which a predicted box is counted as background in the confidence
            loss (see step 3).
    """
    def _yolo_loss(y_true, y_pred):
        """
        y_true and y_pred are (batch_size, number of boxes, 4 (+ 1) + number of classes (+ anchor_id for y_pred)).
        The number of boxes is determined by the network architecture as in single-shot detection one can only
        predict grid_width x grid_height boxes per anchor.
        """
        # 1. Find matching anchors: the anchor with the best IoU is chosen for predicting each true box.
        # Boxes and anchors are compared through their height/width only, as if they shared the same center,
        # so the intersection is simply min(height) * min(width).
        y_true_broadcast = tf.expand_dims(y_true, axis=2)  # (batch_size, boxes, 1, ...)
        anchors_tensor = tf.cast(
            tf.broadcast_to(anchors[['height', 'width']].values, [1, 1, len(anchors), 2]),
            y_true.dtype,
        )
        height_width_min = tf.minimum(y_true_broadcast[..., 2:4], anchors_tensor)
        intersection = tf.reduce_prod(height_width_min, axis=-1)  # (batch_size, boxes, anchors)
        true_box_area = tf.reduce_prod(y_true_broadcast[..., 2:4], axis=-1)
        anchor_boxes_area = tf.reduce_prod(anchors_tensor, axis=-1)
        union = true_box_area + anchor_boxes_area - intersection
        iou = intersection / union
        best_anchor = tf.math.argmax(iou, axis=-1)  # (batch_size, boxes)

        # 2. Find grid cell: for each selected anchor, select the prediction coming from the cell which contains
        # the true box center.
        # TODO: vectorize this selection; the eager loop below is only a draft and will not run in graph mode
        batch_size, boxes, _ = tf.unstack(tf.shape(y_true))
        for image in range(batch_size):
            for box in range(boxes):
                selected_anchor = tf.cast(best_anchor[image, box], y_pred.dtype)
                prediction_for_anchor = tf.boolean_mask(
                    y_pred[image], y_pred[image, :, -1] == selected_anchor, axis=0,
                )

        # 3. For confidence loss: for each selected anchor, compute confidence loss for boxes with IoU < threshold.
        # True boxes are zero-padded, hence boxes with a zero area should be masked out.
        non_empty_boxes_mask = tf.math.reduce_prod(y_true[..., 2:4], axis=-1) > 0
        # TODO: compute and return the localization, confidence and classification terms
        pass

    return _yolo_loss
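
The anchor matching in step 1 compares boxes and anchors by height and width only, as if they shared the same center, so the intersection reduces to min(height) x min(width). A minimal standalone sketch of that matching, with made-up anchors and random labels:

import numpy as np
import pandas as pd
import tensorflow as tf

anchors = pd.DataFrame({'width': [0.1, 0.4], 'height': [0.2, 0.5]})  # hypothetical anchors
y_true = tf.constant(np.random.rand(2, 3, 6), dtype=tf.float32)  # (batch, boxes, y/x/h/w + labels)

height_width = tf.expand_dims(y_true[..., 2:4], axis=2)  # (2, 3, 1, 2)
anchors_tensor = tf.cast(tf.reshape(anchors[['height', 'width']].values, [1, 1, -1, 2]), tf.float32)
intersection = tf.reduce_prod(tf.minimum(height_width, anchors_tensor), axis=-1)  # (2, 3, 2)
union = tf.reduce_prod(height_width, axis=-1) + tf.reduce_prod(anchors_tensor, axis=-1) - intersection
best_anchor = tf.argmax(intersection / union, axis=-1)  # (2, 3): index of the best anchor per true box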
1 change: 1 addition & 0 deletions keras_fsl/models/__init__.py
@@ -1,2 +1,3 @@
from .feature_pyramid_net import FeaturePyramidNet
from .siamese_nets import SiameseNets
from .siamese_detector import SiameseDetector
7 changes: 7 additions & 0 deletions keras_fsl/models/activations/__init__.py
@@ -0,0 +1,7 @@
from .yolo_box import YoloBox
from .yolo_coordinates import YoloCoordinates

__all__ = [
    'YoloBox',
    'YoloCoordinates',
]
23 changes: 23 additions & 0 deletions keras_fsl/models/activations/yolo_box.py
@@ -0,0 +1,23 @@
"""
Activation function for mapping features into output coordinates as in YOLOv3
"""
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Lambda


def YoloBox(anchor):
    """
    Activation function for the box dimension regression. Dimensions are relative to the image dimensions, i.e.
    between 0 and 1.

    Args:
        anchor (Union[pandas.Series, collections.namedtuple]): with width and height keys. Note that given a tensor
            with shape (batch_size, i, j, channels), i is related to height and j to width.
    """
    return Sequential([
        Activation('exponential'),
        Lambda(lambda input_, anchor_=anchor: (
            input_ * tf.convert_to_tensor([anchor_.height, anchor_.width], dtype=tf.float32)
        )),
    ])
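
As in YOLOv3, the regressed dimensions are anchor x exp(raw output), so a zero activation recovers the anchor itself. A quick sanity check of the layer, assuming a 13x13 grid for illustration:

import pandas as pd
import tensorflow as tf

anchor = pd.Series({'width': 116 / 416, 'height': 90 / 416})  # first anchor of the default list
box_hw = YoloBox(anchor)(tf.zeros((1, 13, 13, 2)))  # exp(0) = 1: every cell predicts the raw anchor
assert box_hw.shape == (1, 13, 13, 2)  # last axis is (height, width)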
34 changes: 34 additions & 0 deletions keras_fsl/models/activations/yolo_coordinates.py
@@ -0,0 +1,34 @@
"""
Activation function for mapping features into output coordinates as in YOLOv3
"""
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Lambda


@tf.function
def build_grid_coordinates(grid_shape):
    """
    Build a grid coordinate tensor with shape (*grid_shape, 2) where grid[i, j, 0] = i and grid[i, j, 1] = j

    Args:
        grid_shape (Union[tuple, list, tensorflow.TensorShape]): to be passed to tf.range

    Returns:
        (tensorflow.Tensor)
    """
    # with indexing='ij', the first output varies along the first axis: height[i, j] = i, width[i, j] = j
    height, width = tf.meshgrid(tf.range(grid_shape[0]), tf.range(grid_shape[1]), indexing='ij')
    return tf.stack([height, width], -1)


def YoloCoordinates():
    """
    Activation function for the box center coordinates regression. Coordinates are relative to the image dimensions,
    i.e. between 0 and 1.
    """
    return Sequential([
        Activation('sigmoid'),
        Lambda(lambda input_: input_ + tf.cast(tf.expand_dims(build_grid_coordinates(tf.shape(input_)[1:3]), 0), input_.dtype)),
        Lambda(lambda input_: input_ / tf.cast(tf.shape(input_)[1:3], input_.dtype)),
    ])
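
Following YOLOv3, each cell predicts its box center as a sigmoid offset within the cell, which is then shifted by the cell index and normalized by the grid shape, yielding coordinates in [0, 1]. A small sketch of both helpers, assuming a 13x13 grid:

import tensorflow as tf

grid = build_grid_coordinates((2, 3))  # shape (2, 3, 2) with grid[i, j] == [i, j]

centers = YoloCoordinates()(tf.zeros((1, 13, 13, 2)))
# sigmoid(0) = 0.5, so centers[0, i, j] == [(i + 0.5) / 13, (j + 0.5) / 13]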
60 changes: 41 additions & 19 deletions keras_fsl/models/feature_pyramid_net.py
@@ -1,22 +1,22 @@
from functools import wraps

import pandas as pd
import tensorflow as tf
from tensorflow.keras import Model, Sequential
-from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate
-
-from keras_fsl.models import branch_models
+from tensorflow.keras.layers import Conv2D, BatchNormalization, UpSampling2D, ReLU, Concatenate, Reshape, Lambda
+
+from keras_fsl.models import branch_models, activations

ANCHORS = pd.DataFrame([
-    [0, 116, 90],
-    [0, 156, 198],
-    [0, 373, 326],
-    [1, 30, 61],
-    [1, 62, 45],
-    [1, 59, 119],
-    [2, 10, 13],
-    [2, 16, 30],
-    [2, 33, 23],
+    [0, 116 / 416, 90 / 416],
+    [0, 156 / 416, 198 / 416],
+    [0, 373 / 416, 326 / 416],
+    [1, 30 / 416, 61 / 416],
+    [1, 62 / 416, 45 / 416],
+    [1, 59 / 416, 119 / 416],
+    [2, 10 / 416, 13 / 416],
+    [2, 16 / 416, 30 / 416],
+    [2, 33 / 416, 23 / 416],
], columns=['scale', 'width', 'height'])


@@ -43,6 +43,13 @@ def up_sampling_block(filters, *args, **kwargs):
    ], *args, **kwargs)
+
+
+def regression_block(activation, *args, **kwargs):
+    return Sequential([
+        Conv2D(2, (1, 1)),
+        getattr(activations, activation)(*args),
+    ], **kwargs)


def FeaturePyramidNet(
    backbone='MobileNet',
    *args,
@@ -51,6 +58,8 @@ def FeaturePyramidNet(
    anchors=None,
    classes=None,
    weights=None,
+    coordinates_activation='YoloCoordinates',
+    box_activation='YoloBox',
    **kwargs,
):
    """
@@ -60,6 +69,12 @@
    It analyses the given backbone architecture so as to extract the feature maps at relevant positions (last position before downsampling).
    Then it builds a model with as many feature maps (outputs) as requested, starting from the deepest.
+    When classes is not None, it builds a single-shot detector from the features based on the given list of anchors. In this case, all
+    dimensions are relative to the image dimensions: coordinates and box dimensions will be floats in [0, 1]. Hence anchors are defined
+    with floats for width and height. Each anchor should also specify which feature map it is based on: the current implementation counts
+    backward, with 0 meaning the smallest resolution, 1 the following one, etc. The output of the model is then a list of boxes per image,
+    i.e. (batch_size, number of boxes, {coordinates, (objectness,) labels, anchor_id}).
+
    Args:
        backbone (Union[str, dict, tensorflow.keras.Model]): parameters of the feature extractor
        feature_maps (int): number of feature maps to extract from the backbone.
@@ -69,6 +84,8 @@
            feature map: 0 for the smallest resolution, 1 for the next one, etc.
        classes (pandas.Series): provide classes to build a single-shot detector from the anchors and the feature maps.
        weights (Union[str, pathlib.Path]): path to the weights file to load with tensorflow.keras.load_weights
+        coordinates_activation (str): activation function to be used for the center coordinates regression
+        box_activation (str): activation function to be used for the box height and width regression
    """
    if not isinstance(backbone, Model):
        if isinstance(backbone, str):
@@ -82,7 +99,7 @@
                if isinstance(layer.input_shape, list)
                else layer.output_shape
                for layer in backbone.layers
-            ], columns=['batch_size', 'width', 'height', 'channels'])
+            ], columns=['batch_size', 'height', 'width', 'channels'])
            .loc[lambda df: df.width.iloc[0] % df.width == 0]
            .drop_duplicates(['width', 'height'], keep='last')
            .sort_index(ascending=False)
.sort_index(ascending=False)
Expand All @@ -98,16 +115,21 @@ def FeaturePyramidNet(

if classes is not None:
if anchors is None:
anchors = ANCHORS.copy()
anchors = ANCHORS.copy().round(3)
anchors = anchors.assign(id=lambda df: 'scale_' + df.scale.astype(str) + '_' + df.width.astype(str) + 'x' + df.height.astype(str))
outputs = [
Concatenate(axis=3)(
[Conv2D(4, (1, 1), name=f'{anchor.id}_box')(outputs[anchor.scale])] +
([Conv2D(1, (1, 1), name=f'{anchor.id}_objectness')(outputs[anchor.scale])] if objectness else []) +
[Conv2D(1, (1, 1), name=f'{anchor.id}_{label}')(outputs[anchor.scale]) for label in classes],
)
Reshape((-1, 4 + int(objectness) + len(classes)))(Concatenate(axis=3, name=f'anchor_{anchor.id}_output')(
[regression_block(coordinates_activation, name=f'{anchor.id}_box_yx')(outputs[anchor.scale])] +
[regression_block(box_activation, anchor, name=f'{anchor.id}_box_hw')(outputs[anchor.scale])] +
([Conv2D(1, (1, 1), name=f'{anchor.id}_objectness', activation='sigmoid')(outputs[anchor.scale])] if objectness else []) +
[Conv2D(1, (1, 1), name=f'{anchor.id}_{label}', activation='sigmoid')(outputs[anchor.scale]) for label in classes]
))
for anchor in anchors.itertuples()
]
outputs = Concatenate(axis=1)([
Lambda(lambda output: tf.concat([output, tf.expand_dims(tf.ones(tf.shape(output)[:2], dtype=output.dtype) * index, -1)], axis=-1))(outputs[index])
for index, anchor in anchors.iterrows()
])

    model = Model(backbone.input, outputs, *args, **kwargs)
    if weights is not None:
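
With the Reshape to (-1, channels) and the final Concatenate over anchors, the detector outputs one row per (anchor, grid cell), each row holding 4 coordinates (+ objectness), the class scores and the appended anchor_id. For instance, assuming a 416x416 input with feature maps at strides 32, 16 and 8 (the usual YOLOv3 setup), the nine default anchors would yield:

grid_sizes = {0: 13, 1: 26, 2: 52}  # assumed grids, scale 0 being the smallest resolution
total_boxes = sum(grid_sizes[scale] ** 2 for scale in ANCHORS.scale)
assert total_boxes == 3 * (13 ** 2 + 26 ** 2 + 52 ** 2)  # 10647 rows per image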
2 changes: 0 additions & 2 deletions keras_fsl/models/layers/__init__.py
@@ -1,5 +1,3 @@
-from tensorflow.keras.layers import *
-
from .classification import Classification
from .gram_matrix import GramMatrix
from .slicing import CenterSlicing2D
2 changes: 2 additions & 0 deletions keras_fsl/utils.py
@@ -2,6 +2,8 @@
from functools import wraps
from unittest.mock import patch

+import tensorflow as tf
+

def patch_len(fit_generator):
    """
