Skip to content

Commit

Permalink
Update for arXiv manuscript v2
Browse files Browse the repository at this point in the history
* Replace softmax with relu for heatmap activation.
* Add an implementation of Yaroslav's method.
  • Loading branch information
lishen committed Oct 6, 2017
2 parents 3fbdb8e + d712b81 commit c6a69fa
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ddsm_train/image_clf_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def run(train_dir, val_dir, test_dir, patch_model_state=None, resume_from=None,

# ============ Train & validation set =============== #
train_bs = int(batch_size*train_bs_multiplier)
dup_3_channels = True
if patch_net != 'yaroslav':
dup_3_channels = True
else:
dup_3_channels = False
if load_train_ram:
raw_imgen = DMImageDataGenerator()
print "Create generator for raw train set"
Expand Down
5 changes: 4 additions & 1 deletion ddsm_train/patch_clf_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def run(train_dir, val_dir, test_dir,

# ============ Train & validation set =============== #
train_bs = int(batch_size*train_bs_multiplier)
dup_3_channels = True
if net != 'yaroslav':
dup_3_channels = True
else:
dup_3_channels = False
if load_train_ram:
raw_imgen = DMImageDataGenerator()
print "Create generator for raw train set"
Expand Down
100 changes: 99 additions & 1 deletion dm_keras_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import numpy as np
from keras.callbacks import Callback
from keras.models import load_model, Model
from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from keras.layers import (
Flatten, Dense, Dropout, Input,
GlobalAveragePooling2D, Activation,
MaxPooling2D
)
from keras.layers.convolutional import Conv2D
from keras.regularizers import l2
from keras.optimizers import (
Expand All @@ -17,9 +21,18 @@
from keras.preprocessing.image import flip_axis
import keras.backend as K
data_format = K.image_data_format()
if K.image_data_format() == 'channels_last':
ROW_AXIS = 1
COL_AXIS = 2
CHANNEL_AXIS = 3
else:
CHANNEL_AXIS = 1
ROW_AXIS = 2
COL_AXIS = 3
from sklearn.metrics import roc_auc_score
from dm_resnet import ResNetBuilder
from dm_multi_gpu import make_parallel
from keras.layers.normalization import BatchNormalization


def flip_all_img(X):
Expand Down Expand Up @@ -77,6 +90,86 @@ def load_dat_ram(generator, nb_samples):
return data_set


def Yaroslav(input_shape=None, classes=5):
"""Instantiates the Yaroslav's winning architecture for patch classifiers.
"""
if input_shape is None:
if data_format == 'channels_last':
input_shape = (None, None, 1)
else:
input_shape = (1, None, None)
img_input = Input(shape=input_shape)

# Block 1
x = Conv2D(32, (3, 3), padding='same', name='block1_conv1')(img_input)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(32, (3, 3), padding='same', name='block1_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

# Block 2
x = Conv2D(64, (3, 3), padding='same', name='block2_conv1')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(64, (3, 3), padding='same', name='block2_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

# Block 3
x = Conv2D(128, (3, 3), padding='same', name='block3_conv1')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(128, (3, 3), padding='same', name='block3_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

# Block 4
x = Conv2D(256, (3, 3), padding='same', name='block4_conv1')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(256, (3, 3), padding='same', name='block4_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

# Block 5
x = Conv2D(256, (3, 3), padding='same', name='block5_conv1')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(256, (3, 3), padding='same', name='block5_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

# Block 6
x = Conv2D(512, (3, 3), padding='same', name='block6_conv1')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = Conv2D(512, (3, 3), padding='same', name='block6_conv2')(x)
x = BatchNormalization(axis=CHANNEL_AXIS)(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block6_pool')(x)

# Classification block
#x = Flatten(name='flatten')(x)
#x = Dense(1024, name='fc1')(x)
#x = BatchNormalization()(x)
#x = Activation('relu')(x)
#x = Dense(512, name='fc2')(x)
#x = BatchNormalization()(x)
#x = Activation('relu')(x)
x = GlobalAveragePooling2D()(x)
x = Dense(classes, activation='softmax', name='predictions')(x)

# Create model.
model = Model(img_input, x, name='yaroslav')
return model


def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None,
top_layer_nb=None, weight_decay=.01,
hidden_dropout=.0, **kw_args):
Expand All @@ -99,6 +192,9 @@ def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None,
elif net == 'inception':
from keras.applications.inception_v3 import InceptionV3 as NNet, preprocess_input
top_layer_nb = 194 if top_layer_nb is None else top_layer_nb
elif net == 'yaroslav':
top_layer_nb = None
preprocess_input = None
else:
raise Exception("Requested model is not available: " + net)
weights = 'imagenet' if use_pretrained else None
Expand All @@ -108,6 +204,8 @@ def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None,
sys.stdout.flush()
model = load_model(resume_from)
print "Done."
elif net == 'yaroslav':
model = Yaroslav(classes=nb_class)
else:
print "Loading %s," % (net),
sys.stdout.flush()
Expand Down
13 changes: 9 additions & 4 deletions dm_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,13 @@ def add_fc_layers(block):

if patch_net == 'resnet50':
last_kept_layer = model.layers[-5]
elif patch_net == 'yaroslav':
last_kept_layer = model.layers[-3]
else:
last_kept_layer = model.layers[-4]
block = last_kept_layer.output
image_input = Input(shape=(image_size[0],image_size[1],3))
channels = 1 if patch_net == 'yaroslav' else 3
image_input = Input(shape=(image_size[0], image_size[1], channels))
model0 = Model(inputs=model.inputs, outputs=block)
block = model0(image_input)
if add_heatmap or return_heatmap: # add softmax heatmap.
Expand All @@ -253,9 +256,11 @@ def add_fc_layers(block):
clf_layer = model.layers[-1]
clf_weights = clf_layer.get_weights()
clf_classes = clf_layer.output_shape[1]
def softmax(x):
return activations.softmax(x, axis=CHANNEL_AXIS)
heatmap_layer = Dense(clf_classes, activation=softmax,
if return_heatmap:
activation = activations.softmax(x, axis=CHANNEL_AXIS)
else:
activation = 'relu'
heatmap_layer = Dense(clf_classes, activation=activation,
kernel_regularizer=l2(weight_decay))
heatmap = heatmap_layer(dropped)
heatmap_layer.set_weights(clf_weights)
Expand Down

0 comments on commit c6a69fa

Please sign in to comment.