# %% IMPORTS
from sklearn.ensemble import RandomForestClassifier
from skimage.io import imshow
import numpy as np
import napari
from skimage import data, filters
import matplotlib.pyplot as plt
# %% LOAD AND ANNOTATE THE IMAGE
# Load the cells3d example image; axes are (z, channel, y, x) and channel 1 holds the nuclei:
image = data.cells3d()
# extract the nuclei channel and z-project it with a maximum projection:
image_2D = image[:, 1, ...].max(axis=0)
# interactively label the nuclei:
# start napari and add image:
viewer = napari.Viewer()
viewer.add_image(image_2D)
# add an empty labels layer:
labels = viewer.add_labels(np.zeros(image_2D.shape, dtype=int))
"""
In Napari, we can use the labels layer to interactively label the nuclei. Label
some nuclei (label 2) and background pixels (label 1) in the label layer.
When you're done, execute the next cell.
"""
# %% VIEW ANNOTATIONS
# take a screenshot of the annotation:
napari.utils.nbscreenshot(viewer)
# retrieve the annotations from the napari layer:
annotations = labels.data
# plot the original image and the annotations side-by-side in a subplot:
fig, axes = plt.subplots(ncols=2, figsize=(8, 4))
axes[0].imshow(image_2D)
axes[0].set_title('Original image')
axes[1].imshow(annotations)
axes[1].set_title('Annotations')
plt.show()
# %% GENERATE IMAGE FEATURE STACK
"""
We now increase the number of features, that we can take out of our image. Our image
will become a 3D stack of 2D images, where each 2D image is a feature, consisting of
* the original pixel value
* the pixel value after a Gaussian blur (=denoising)
* the pixel value of the Gaussian blurred image processed through a Sobel operator (=edge detection)
"""
def generate_feature_stack(image):
    # determine features
    blurred = filters.gaussian(image, sigma=2)
    edges = filters.sobel(blurred)
    """
    Collect features in a stack. The ravel() function turns an nD image into
    a 1-D array. We need it because scikit-learn expects the values in a
    1-D format here.
    """
    feature_stack = [image.ravel(),
                     blurred.ravel(),
                     edges.ravel()]
    return np.asarray(feature_stack)
feature_stack = generate_feature_stack(image_2D)
# show feature images:
fig, ax = plt.subplots(1, 3, figsize=(10,10))
# reshape(image.shape) is the opposite of ravel() here. We just need it for visualization.
ax[0].imshow(feature_stack[0].reshape(image_2D.shape), cmap=plt.cm.gray)
ax[0].set_title('Original image')
ax[1].imshow(feature_stack[1].reshape(image_2D.shape), cmap=plt.cm.gray)
ax[1].set_title('Blurred image')
ax[2].imshow(feature_stack[2].reshape(image_2D.shape), cmap=plt.cm.gray)
ax[2].set_title('Edges')
plt.show()
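# %% (OPTIONAL) RICHER FEATURE STACK
"""
A sketch of a richer feature set, assuming scikit-image >= 0.19 which provides
skimage.feature.multiscale_basic_features: it computes intensity, edge and
texture features over several Gaussian scales and returns them in the same
(n_features, n_pixels) layout as generate_feature_stack(). This is an optional
alternative, not part of the original workflow.
"""
from skimage.feature import multiscale_basic_features

def generate_multiscale_feature_stack(image):
    # multiscale_basic_features returns an array of shape (height, width, n_features)
    features = multiscale_basic_features(image, intensity=True, edges=True,
                                         texture=True, sigma_min=1, sigma_max=8)
    # flatten the spatial axes and move the feature axis first
    return features.reshape(-1, features.shape[-1]).T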
# %% FORMATTING DATA
"""We now need to format the input data so that it fits to what scikit learn expects.
Scikit-learn asks for an array of shape (n, m) as input data and (n) annotations.
n corresponds to number of pixels and m to number of features. In our case m = 3.
"""
def format_data(feature_stack, annotation):
    # reformat the data to match what scikit-learn expects
    # transpose the feature stack so that rows are pixels and columns are features
    X = feature_stack.T
    # make the annotation 1-dimensional
    y = annotation.ravel()
    # remove all pixels from the features and annotations which have not been annotated
    mask = y > 0
    X = X[mask]
    y = y[mask]
    return X, y
X, y = format_data(feature_stack, annotations)
print("input shape", X.shape)
print("annotation shape", y.shape)
# %% TRAIN AND PREDICT WITH RANDOM FOREST CLASSIFIER
classifier = RandomForestClassifier(max_depth=10, random_state=0)
classifier.fit(X, y)
result = classifier.predict(feature_stack.T) - 1  # subtract 1 so that background = 0
result_2d = result.reshape(image_2D.shape)
imshow(result_2d)
viewer.add_labels(result_2d)
napari.utils.nbscreenshot(viewer)
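# %% (OPTIONAL) INSPECT THE TRAINED FOREST
"""
An optional addition: feature_importances_ shows how much each of the three
features contributes to the decision, and score(X, y) reports accuracy on the
annotated pixels (the training data, so this estimate is optimistic).
"""
for name, importance in zip(['original', 'blurred', 'edges'],
                            classifier.feature_importances_):
    print(name, importance)
print("accuracy on annotated pixels:", classifier.score(X, y))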
# %% SECOND EXAMPLE (SKIN): LOAD AND ANNOTATE
# use the red channel of the RGB skin example image:
image_2d_2 = data.skin()[:, :, 0]
viewer = napari.Viewer()
viewer.add_image(image_2d_2)
# add an empty labels layer:
labels = viewer.add_labels(np.zeros(image_2d_2.shape, dtype=int))
# %% SECOND EXAMPLE (SKIN): VIEW ANNOTATIONS
# take a screenshot of the annotation:
napari.utils.nbscreenshot(viewer)
# retrieve the annotations from the napari layer:
annotations = labels.data
# plot the original image and the annotations side-by-side in a subplot:
fig, axes = plt.subplots(ncols=2, figsize=(8, 4))
axes[0].imshow(image_2d_2)
axes[0].set_title('Original image')
axes[1].imshow(annotations)
axes[1].set_title('Annotations')
plt.show()
# %% SECOND EXAMPLE (SKIN): GENERATE IMAGE FEATURE STACK
feature_stack = generate_feature_stack(image_2d_2)
# show feature images:
fig, ax = plt.subplots(1, 3, figsize=(10,10))
# reshape(image.shape) is the opposite of ravel() here. We just need it for visualization.
ax[0].imshow(feature_stack[0].reshape(image_2d_2.shape), cmap=plt.cm.gray)
ax[0].set_title('Original image')
ax[1].imshow(feature_stack[1].reshape(image_2d_2.shape), cmap=plt.cm.gray)
ax[1].set_title('Blurred image')
ax[2].imshow(feature_stack[2].reshape(image_2d_2.shape), cmap=plt.cm.gray)
ax[2].set_title('Edges')
plt.show()
# %% SECOND EXAMPLE (SKIN): FORMATTING DATA
X, y = format_data(feature_stack, annotations)
print("input shape", X.shape)
print("annotation shape", y.shape)
# %% SECOND EXAMPLE (SKIN): TRAIN AND PREDICT WITH RANDOM FOREST CLASSIFIER
# max_samples=0.05 lets each tree train on 5 % of the annotated pixels and
# n_estimators=50 limits the number of trees, which keeps training fast on this larger image
classifier = RandomForestClassifier(max_depth=10, random_state=0, max_samples=0.05, n_estimators=50)
classifier.fit(X, y)
result = classifier.predict(feature_stack.T) - 1  # subtract 1 so that background = 0
result_2d = result.reshape(image_2d_2.shape)
imshow(result_2d)
viewer.add_labels(result_2d)
napari.utils.nbscreenshot(viewer)
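# %% (OPTIONAL) SAVE THE RESULTS
"""
A sketch for persisting the outputs; the file names below are just examples.
The trained classifier can be stored with joblib and the label image with
skimage.io.imsave, so the segmentation can be reused without re-annotating.
"""
from joblib import dump
from skimage.io import imsave
dump(classifier, 'rf_pixel_classifier.joblib')                # example file name
imsave('skin_segmentation.tif', result_2d.astype(np.uint8))   # example file name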
# %% END