-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstain_norm.py
62 lines (54 loc) · 2.22 KB
/
stain_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Test stain normalization using stain tools
import staintools
import pandas as pd
import os
import pdb
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
# Seed the Python and NumPy global RNGs (presumably so any randomness used
# while fitting/applying the normalizer is reproducible across runs).
random.seed(0)
np.random.seed(0)

# Whether to also save a side-by-side original/normalized thumbnail per image.
# (The misspelled name is kept so existing references keep working.)
plot_thumnail = True

# Directory containing the original images to be normalized.
patch_path = "./sample_images"
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# images are processed and logged in a deterministic order.
original_images = sorted(os.listdir(patch_path))

# Reference image whose stain appearance every other image is mapped onto.
target_image_path = "./ref_image/ref_image.tiff"
target = staintools.read_image(target_image_path)
# Standardize brightness (optional; can improve the tissue mask calculation).
target = staintools.LuminosityStandardizer.standardize(target)
# Fit the Vahadane normalizer once on the reference; it is reused for every image.
normalizer = staintools.StainNormalizer(method='vahadane')
normalizer.fit(target)

# Output directory for normalized images (created if it does not exist).
save_path = "./output"
os.makedirs(save_path, exist_ok=True)
if plot_thumnail:
    # Comparison thumbnails go in a subdirectory of the output directory.
    thumb_path = os.path.join(save_path, "thumbnails")
    os.makedirs(thumb_path, exist_ok=True)
# Normalize every image in patch_path and write the result (same file name)
# into save_path; optionally also write a comparison thumbnail into thumb_path.
for image in original_images:
    print(f"Processing {image}")
    src_file = os.path.join(patch_path, image)
    out_file = os.path.join(save_path, image)

    # os.listdir also returns subdirectories; only regular files can be read.
    if not os.path.isfile(src_file):
        print(f"Skipping {image}: not a regular file")
        continue

    # Resume support: an existing output file means the image was already done.
    # NOTE(review): if a previous run saved the image but failed while writing
    # its thumbnail, the thumbnail is never regenerated because of this skip.
    if os.path.exists(out_file):
        print(f"{image} has been processed, skipping")
        continue

    try:
        original = staintools.read_image(src_file)
        # Apply the same luminosity standardization used on the reference image.
        br_standard = staintools.LuminosityStandardizer.standardize(original)
        transformed = normalizer.transform(br_standard)
        Image.fromarray(transformed).save(out_file)

        if plot_thumnail:
            # Side-by-side thumbnail to compare original versus normalized.
            # The thumbnail reuses the image's file name, so matplotlib infers
            # the output format from its extension.
            fig, axes = plt.subplots(1, 2)
            try:
                panels = zip(axes, (original, transformed), ("Original", "Normalized"))
                for ax, shown, title in panels:
                    ax.imshow(shown)
                    ax.set_title(title)
                    # Hide the ticks and tick labels.
                    ax.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
                fig.savefig(os.path.join(thumb_path, image), dpi=400)
            finally:
                # Always release this figure, even if plotting/saving failed,
                # so failures do not accumulate open figures across the loop.
                plt.close(fig)
    except Exception as e:
        # Best-effort batch: report which image failed and continue with the rest.
        print(f"Failed to process {image}: {e}")