Skip to content

Commit

Permalink
Update basnet segmentation example
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Oct 23, 2024
1 parent 9fdad44 commit 3117146
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 120 deletions.
107 changes: 56 additions & 51 deletions examples/vision/basnet_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
structures common to real-world images in both foreground and background.
"""

"""shell
wget http://saliencydetection.net/duts/download/DUTS-TE.zip
unzip -q DUTS-TE.zip
"""

import os

# Because of the use of tf.image.ssim in the loss,
# this example requires TensorFlow. The rest of the code
# is backend-agnostic.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
from glob import glob
import matplotlib.pyplot as plt
Expand All @@ -49,6 +48,8 @@
import keras
from keras import layers, ops

keras.config.disable_traceback_filtering()

"""
## Define Hyperparameters
"""
Expand All @@ -57,15 +58,20 @@
BATCH_SIZE = 4
OUT_CLASSES = 1
TRAIN_SPLIT_RATIO = 0.90
DATA_DIR = "./DUTS-TE/"

"""
## Create `PyDataset`s
## Create `PyDataset`s
We will use `load_paths()` to load 140 paths and split them into train and validation sets,
and then convert the paths into `PyDataset` objects.
"""

# Download the DUTS-TE archive from `origin` and extract it (`extract=True`).
# This step needs network access the first time it runs.
data_dir = keras.utils.get_file(
origin="http://saliencydetection.net/duts/download/DUTS-TE.zip",
extract=True,
)
# Point `data_dir` at the "DUTS-TE" folder under the path returned above.
# NOTE(review): assumes `get_file(..., extract=True)` returns the extraction
# directory rather than the archive path — confirm for the installed Keras version.
data_dir = os.path.join(data_dir, "DUTS-TE")


def load_paths(path, split_ratio):
images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
Expand Down Expand Up @@ -103,7 +109,9 @@ def __getitem__(self, idx):
batch_x, batch_y = [], []
for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
x, y = self.preprocess(
self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes
self.image_paths[i],
self.mask_paths[i],
self.img_size,
)
batch_x.append(x)
batch_y.append(y)
Expand All @@ -117,13 +125,13 @@ def read_image(self, path, size, mode):
x = (x / 255.0).astype(np.float32)
return x

def preprocess(self, x_batch, y_batch, img_size, out_classes):
def preprocess(self, x_batch, y_batch, img_size):
images = self.read_image(x_batch, (img_size, img_size), mode="rgb") # image
masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale") # mask
return images, masks


train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)
train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)

train_dataset = Dataset(
train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
Expand All @@ -148,8 +156,9 @@ def display(display_list):
plt.show()


for (image, mask), _ in zip(val_dataset, range(1)):
for image, mask in val_dataset:
display([image[0], mask[0]])
break

"""
## Analyze Mask
Expand Down Expand Up @@ -343,52 +352,37 @@ def basnet_rrm(base_model, out_classes):
# ------------- refined = coarse + residual
x = layers.Add()([x_input, x]) # Add prediction + refinement output

return keras.models.Model(inputs=[base_model.input], outputs=[x])
return keras.models.Model(inputs=base_model.input[0], outputs=x)


"""
## Combine Predict and Refinement Module
"""


def basnet(input_shape, out_classes):
    """Build BASNet from its Prediction Module and Residual Refinement Module (RRM).

    Args:
        input_shape: Input shape forwarded to `basnet_predict`.
        out_classes: Number of output classes, forwarded to both module builders.

    Returns:
        A `keras.models.Model` that takes the prediction model's input and whose
        outputs are the refinement outputs followed by the prediction outputs,
        each passed through a sigmoid activation.
    """
    # Prediction model.
    predict_model = basnet_predict(input_shape, out_classes)
    # Refinement model.
    refine_model = basnet_rrm(predict_model, out_classes)

    # Combine outputs into a *new* list. Extending `refine_model.outputs`
    # directly would modify the refinement model's own outputs list in place.
    outputs = [*refine_model.outputs, *predict_model.output]

    # Activations.
    outputs = [layers.Activation("sigmoid")(tensor) for tensor in outputs]

    return keras.models.Model(inputs=[predict_model.input], outputs=outputs)


"""
## Hybrid Loss
class BASNet(keras.Model):
def __init__(self, input_shape, out_classes):
"""BASNet, it's a combination of two modules
Prediction Module and Residual Refinement Module(RRM)."""

Another important feature of BASNet is its hybrid loss function, a combination of
binary cross-entropy, structural similarity, and intersection-over-union losses that guides
the network to learn a three-level (i.e., pixel-, patch-, and map-level) hierarchy of representations.
"""
# Prediction model.
predict_model = basnet_predict(input_shape, out_classes)
# Refinement model.
refine_model = basnet_rrm(predict_model, out_classes)

output = refine_model.outputs # Combine outputs.
output.extend(predict_model.output)

class BasnetLoss(keras.losses.Loss):
"""BASNet hybrid loss."""
# Activations.
output = [layers.Activation("sigmoid")(x) for x in output]
super().__init__(inputs=predict_model.input[0], outputs=output)

def __init__(self, **kwargs):
super().__init__(name="basnet_loss", **kwargs)
self.smooth = 1.0e-9

# Binary Cross Entropy loss.
self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
# Structural Similarity Index value.
self.ssim_value = tf.image.ssim
# Jaccard / IoU loss.
# Jaccard / IoU loss.
self.iou_value = self.calculate_iou

def calculate_iou(
Expand All @@ -402,28 +396,39 @@ def calculate_iou(
union = union - intersection
return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)

def call(self, y_true, y_pred):
cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)
def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
    """Return the hybrid BASNet loss accumulated over every model output.

    For each prediction in `y_pred`, three terms are added together:
    binary cross-entropy, an SSIM term (`1 - ssim + smooth`, averaged over
    axis 0), and an IoU term (`1 - iou`).

    Args:
        x: Input batch; not used by this loss.
        y_true: Ground-truth masks, compared against every output.
        y_pred: Iterable of predictions, one per model output
            (refine_model.outputs + predict_model.output).
        sample_weight: Accepted for interface compatibility; not used here.
        training: Accepted for interface compatibility; not used here.

    Returns:
        The sum of the per-output losses (`0.0` when `y_pred` is empty).
    """
    total = 0.0
    for y_pred_i in y_pred:
        cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)

        # Score the current output `y_pred_i`. Passing the whole `y_pred`
        # list here would compare `y_true` against all outputs at once on
        # every iteration instead of against this one output.
        ssim_value = self.ssim_value(y_true, y_pred_i, max_val=1)
        ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)

        iou_value = self.iou_value(y_true, y_pred_i)
        iou_loss = 1 - iou_value

        # Add all three losses.
        total += cross_entropy_loss + ssim_loss + iou_loss
    return total

iou_value = self.iou_value(y_true, y_pred)
iou_loss = 1 - iou_value

# Add all three losses.
return cross_entropy_loss + ssim_loss + iou_loss
"""
## Hybrid Loss
Another important feature of BASNet is its hybrid loss function, a combination of
binary cross-entropy, structural similarity, and intersection-over-union losses that guides
the network to learn a three-level (i.e., pixel-, patch-, and map-level) hierarchy of representations.
"""


basnet_model = basnet(
basnet_model = BASNet(
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
) # Create model.
basnet_model.summary() # Show model summary.

optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
# Compile model.
basnet_model.compile(
loss=BasnetLoss(),
optimizer=optimizer,
metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
)
Expand Down
Loading

0 comments on commit 3117146

Please sign in to comment.