-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathapp.py
132 lines (104 loc) · 4.3 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from flask import Flask, render_template, request
import os
import torch
from PIL import Image
from werkzeug.utils import secure_filename
import uuid
import logging
from ultralytics import YOLO
from model_utils import (
convert_dicom_to_image,
transform,
extract_patch_with_yolo,
siamese_base_inference,
load_siamese_models,
)
import glob
app = Flask(__name__)

# Uploaded DICOM files are saved here by upload_dcms. The path is relative to
# the current working directory.
UPLOAD_FOLDER = "static/uploads"
# exist_ok=True replaces the old exists()/makedirs() pair, which could raise
# FileExistsError if two processes imported this module at the same moment.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

logging.basicConfig(level=logging.INFO)

# Patch tensors are moved to this device before Siamese inference; use the
# GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# YOLO checkpoint used by extract_patch_with_yolo to locate patches.
YOLO_PT_AXIAL_T2_PATH = "models/axialY/best.pt"
detector = YOLO(YOLO_PT_AXIAL_T2_PATH)

# Siamese checkpoints, sorted so the load order does not depend on the
# filesystem's directory listing order.
SIAMESE_AXIAL_T2_PT_LIST = sorted(glob.glob("models/axial/*.pth"))
if not SIAMESE_AXIAL_T2_PT_LIST:
    # Presumably load_siamese_models([]) yields no models, in which case every
    # upload would end in "No valid patches processed."; surface it at startup.
    logging.warning("No Siamese checkpoints found matching models/axial/*.pth")
siamese_models = load_siamese_models(SIAMESE_AXIAL_T2_PT_LIST)
@app.route('/')
def home():
    """Render the ``home.html`` template for the site root ("/")."""
    template_name = "home.html"
    return render_template(template_name)
@app.route('/contact')
def contact():
    """Render the ``contact.html`` template for "/contact"."""
    template_name = "contact.html"
    return render_template(template_name)
@app.route('/dcm_upload')
def dcm_upload():
    """Render the ``upload.html`` template for "/dcm_upload".

    The file-submission endpoint itself is the POST-only "/upload" route.
    """
    template_name = "upload.html"
    return render_template(template_name)
def _save_patches(patches, dicom_name, upload_id):
    """Write each detected patch as a PNG; return their paths relative to ``static/``.

    Files go to ``static/patches/<dicom_name>_<upload_id>/patch_<i>.png`` with a
    1-based ``i``. The per-upload suffix keeps a re-upload of a same-named
    DICOM from overwriting an earlier run's images, and stops the browser from
    showing cached copies of them under the same URL.

    The returned paths always use "/" separators and never include the leading
    "static" component, on both Windows and POSIX.
    """
    patch_folder = os.path.join("static", "patches", f"{dicom_name}_{upload_id}")
    os.makedirs(patch_folder, exist_ok=True)
    rel_paths = []
    for i, patch in enumerate(patches, 1):
        patch_path = os.path.join(patch_folder, f"patch_{i}.png")
        Image.fromarray(patch).save(patch_path)
        # relpath strips "static" regardless of separator style; then normalise
        # the OS separator to "/" for use in URLs.
        rel_paths.append(os.path.relpath(patch_path, "static").replace(os.sep, "/"))
    return rel_paths


def _score_patches(patches):
    """Average the Siamese ensemble's class probabilities over all patches.

    Every model in ``siamese_models`` is run on every patch. A result equal to
    the string "Uncertain" is skipped. All other results are summed per class
    and divided by the number of results kept.

    Returns a dict with keys "Normal_Mild", "Moderate", "Severe" (values
    rounded to 4 decimals), or ``None`` when no result was kept.
    """
    totals = {"Normal_Mild": 0.0, "Moderate": 0.0, "Severe": 0.0}
    n_kept = 0
    for patch in patches:
        patch_tensor = transform(Image.fromarray(patch)).unsqueeze(0).to(DEVICE)
        for models in siamese_models.values():
            for model in models:
                probs = siamese_base_inference(model, patch_tensor)
                if isinstance(probs, str) and probs == "Uncertain":
                    continue  # drop only this model's vote, not the whole patch
                # assumes probs' keys are among the three in `totals` -- TODO confirm
                for cls, prob in probs.items():
                    totals[cls] += prob
                n_kept += 1
    if n_kept == 0:
        return None
    return {cls: round(total / n_kept, 4) for cls, total in totals.items()}


@app.route('/upload', methods=["POST"])
def upload_dcms():
    """Run YOLO patch detection and Siamese scoring on each uploaded DICOM.

    Reads the multipart form field ``files``. For every file it saves the
    upload, converts it with ``convert_dicom_to_image``, extracts patches with
    the YOLO ``detector``, writes them as PNGs and averages the Siamese class
    scores.

    Each entry passed to ``predictions.html`` is either
    ``{"file": ..., "error": ...}`` or
    ``{"file", "normal_mild", "moderate", "severe", "patches"}``.
    ``patches`` holds paths relative to ``static/``.

    Returns:
        The rendered ``predictions.html`` page, or a ``(message, 400)`` tuple
        when the request carries no usable files.
    """
    if 'files' not in request.files:
        return "No files uploaded", 400
    files = request.files.getlist('files')
    if not files:
        return "No files selected", 400

    all_predictions = []
    for file in files:
        if not file.filename:
            return "File name is empty or invalid", 400

        safe_name = secure_filename(file.filename)
        if not safe_name:
            # secure_filename may return "" (e.g. for "../"). Joining that with
            # UPLOAD_FOLDER would make file.save target the directory itself.
            all_predictions.append({"file": file.filename, "error": "Invalid file name."})
            continue

        # Per-upload id so identically named uploads never overwrite each other.
        upload_id = uuid.uuid4().hex
        filepath = os.path.join(UPLOAD_FOLDER, f"{upload_id}_{safe_name}")
        # NOTE(review): UPLOAD_FOLDER is "static/uploads". When run from the app
        # directory that is inside Flask's default static folder, so raw DICOMs
        # appear to be publicly downloadable, and nothing here deletes them.
        # Confirm this is intended.
        file.save(filepath)
        logging.info("Processing file: %s", file.filename)

        image = convert_dicom_to_image(filepath)
        if image is None:
            logging.warning("Failed to convert %s to image", file.filename)
            all_predictions.append({"file": file.filename, "error": "Failed to process DICOM file."})
            continue

        patches = extract_patch_with_yolo(image, detector)
        logging.info("Detected %d patches for %s", len(patches), file.filename)
        if not patches:
            all_predictions.append({"file": file.filename, "error": "No patches detected."})
            continue

        dicom_name = os.path.splitext(safe_name)[0]
        patch_paths = _save_patches(patches, dicom_name, upload_id)

        predictions = _score_patches(patches)
        if predictions is None:
            all_predictions.append({"file": file.filename, "error": "No valid patches processed."})
            continue
        logging.info("Final predictions for %s: %s", file.filename, predictions)

        all_predictions.append({
            "file": file.filename,
            "normal_mild": predictions["Normal_Mild"],
            "moderate": predictions["Moderate"],
            "severe": predictions["Severe"],
            "patches": patch_paths,
        })
    return render_template("predictions.html", predictions=all_predictions)
if __name__ == "__main__":
    # Development server entry point.
    # Debug mode stays on by default (as before), but it can now be disabled
    # with FLASK_DEBUG=0/false/no (or empty), using the same truthiness rule as
    # Flask. Debug mode enables the Werkzeug interactive debugger, which must
    # never be reachable outside a trusted development machine.
    _debug_env = os.environ.get("FLASK_DEBUG")
    if _debug_env is None:
        _debug = True
    else:
        _debug = bool(_debug_env) and _debug_env.strip().lower() not in ("0", "false", "no")
    app.run(debug=_debug)