diff --git a/ddsm_train/image_clf_train.py b/ddsm_train/image_clf_train.py index ed7ecf8..38577a5 100644 --- a/ddsm_train/image_clf_train.py +++ b/ddsm_train/image_clf_train.py @@ -89,7 +89,10 @@ def run(train_dir, val_dir, test_dir, patch_model_state=None, resume_from=None, # ============ Train & validation set =============== # train_bs = int(batch_size*train_bs_multiplier) - dup_3_channels = True + if patch_net != 'yaroslav': + dup_3_channels = True + else: + dup_3_channels = False if load_train_ram: raw_imgen = DMImageDataGenerator() print "Create generator for raw train set" diff --git a/ddsm_train/patch_clf_train.py b/ddsm_train/patch_clf_train.py index 65f55ca..717d9e4 100644 --- a/ddsm_train/patch_clf_train.py +++ b/ddsm_train/patch_clf_train.py @@ -82,7 +82,10 @@ def run(train_dir, val_dir, test_dir, # ============ Train & validation set =============== # train_bs = int(batch_size*train_bs_multiplier) - dup_3_channels = True + if net != 'yaroslav': + dup_3_channels = True + else: + dup_3_channels = False if load_train_ram: raw_imgen = DMImageDataGenerator() print "Create generator for raw train set" diff --git a/dm_keras_ext.py b/dm_keras_ext.py index eafd2a4..c0cf9c2 100644 --- a/dm_keras_ext.py +++ b/dm_keras_ext.py @@ -2,7 +2,11 @@ import numpy as np from keras.callbacks import Callback from keras.models import load_model, Model -from keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D +from keras.layers import ( + Flatten, Dense, Dropout, Input, + GlobalAveragePooling2D, Activation, + MaxPooling2D +) from keras.layers.convolutional import Conv2D from keras.regularizers import l2 from keras.optimizers import ( @@ -17,9 +21,18 @@ from keras.preprocessing.image import flip_axis import keras.backend as K data_format = K.image_data_format() +if K.image_data_format() == 'channels_last': + ROW_AXIS = 1 + COL_AXIS = 2 + CHANNEL_AXIS = 3 +else: + CHANNEL_AXIS = 1 + ROW_AXIS = 2 + COL_AXIS = 3 from sklearn.metrics import roc_auc_score from dm_resnet import ResNetBuilder from dm_multi_gpu import make_parallel +from keras.layers.normalization import BatchNormalization def flip_all_img(X): @@ -77,6 +90,86 @@ def load_dat_ram(generator, nb_samples): return data_set +def Yaroslav(input_shape=None, classes=5): + """Instantiates the Yaroslav's winning architecture for patch classifiers. + """ + if input_shape is None: + if data_format == 'channels_last': + input_shape = (None, None, 1) + else: + input_shape = (1, None, None) + img_input = Input(shape=input_shape) + + # Block 1 + x = Conv2D(32, (3, 3), padding='same', name='block1_conv1')(img_input) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(32, (3, 3), padding='same', name='block1_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) + + # Block 2 + x = Conv2D(64, (3, 3), padding='same', name='block2_conv1')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(64, (3, 3), padding='same', name='block2_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) + + # Block 3 + x = Conv2D(128, (3, 3), padding='same', name='block3_conv1')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(128, (3, 3), padding='same', name='block3_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) + + # Block 4 + x = Conv2D(256, (3, 3), padding='same', name='block4_conv1')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', name='block4_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) + + # Block 5 + x = Conv2D(256, (3, 3), padding='same', name='block5_conv1')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(256, (3, 3), padding='same', name='block5_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) + + # Block 6 + x = Conv2D(512, (3, 3), padding='same', name='block6_conv1')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = Conv2D(512, (3, 3), padding='same', name='block6_conv2')(x) + x = BatchNormalization(axis=CHANNEL_AXIS)(x) + x = Activation('relu')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='block6_pool')(x) + + # Classification block + #x = Flatten(name='flatten')(x) + #x = Dense(1024, name='fc1')(x) + #x = BatchNormalization()(x) + #x = Activation('relu')(x) + #x = Dense(512, name='fc2')(x) + #x = BatchNormalization()(x) + #x = Activation('relu')(x) + x = GlobalAveragePooling2D()(x) + x = Dense(classes, activation='softmax', name='predictions')(x) + + # Create model. + model = Model(img_input, x, name='yaroslav') + return model + + def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None, top_layer_nb=None, weight_decay=.01, hidden_dropout=.0, **kw_args): @@ -99,6 +192,9 @@ def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None, elif net == 'inception': from keras.applications.inception_v3 import InceptionV3 as NNet, preprocess_input top_layer_nb = 194 if top_layer_nb is None else top_layer_nb + elif net == 'yaroslav': + top_layer_nb = None + preprocess_input = None else: raise Exception("Requested model is not available: " + net) weights = 'imagenet' if use_pretrained else None @@ -108,6 +204,8 @@ def get_dl_model(net, nb_class=3, use_pretrained=True, resume_from=None, sys.stdout.flush() model = load_model(resume_from) print "Done." + elif net == 'yaroslav': + model = Yaroslav(classes=nb_class) else: print "Loading %s," % (net), sys.stdout.flush() diff --git a/dm_resnet.py b/dm_resnet.py index 2905a53..4f53d2a 100644 --- a/dm_resnet.py +++ b/dm_resnet.py @@ -237,10 +237,13 @@ def add_fc_layers(block): if patch_net == 'resnet50': last_kept_layer = model.layers[-5] + elif patch_net == 'yaroslav': + last_kept_layer = model.layers[-3] else: last_kept_layer = model.layers[-4] block = last_kept_layer.output - image_input = Input(shape=(image_size[0],image_size[1],3)) + channels = 1 if patch_net == 'yaroslav' else 3 + image_input = Input(shape=(image_size[0], image_size[1], channels)) model0 = Model(inputs=model.inputs, outputs=block) block = model0(image_input) if add_heatmap or return_heatmap: # add softmax heatmap. @@ -253,9 +256,11 @@ def add_fc_layers(block): clf_layer = model.layers[-1] clf_weights = clf_layer.get_weights() clf_classes = clf_layer.output_shape[1] - def softmax(x): - return activations.softmax(x, axis=CHANNEL_AXIS) - heatmap_layer = Dense(clf_classes, activation=softmax, + if return_heatmap: + activation = activations.softmax(x, axis=CHANNEL_AXIS) + else: + activation = 'relu' + heatmap_layer = Dense(clf_classes, activation=activation, kernel_regularizer=l2(weight_decay)) heatmap = heatmap_layer(dropped) heatmap_layer.set_weights(clf_weights)