-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
classification/dataset/ | ||
train/classification/dataset/ | ||
__pycache__/ | ||
*.mp4 | ||
classification/images/ | ||
classification/experiment-images/ | ||
test-preprocessing/notebooks/checkpoint/ | ||
train/classification/images/ | ||
train/classification/experiment-images/ | ||
test/test-preprocessing/notebooks/checkpoint/ | ||
*.pt | ||
*.pth | ||
*.jpg | ||
*.txt | ||
*.png | ||
*.log | ||
*.log | ||
*.zip |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import imp | ||
import torch | ||
import torch.hub | ||
import torch.nn as nn | ||
from .segmentation import * | ||
from .unet import build_unet | ||
from .utils import get_blobs, preprocess_image_classification | ||
|
||
if torch.cuda.is_available(): | ||
torch.backends.cudnn.deterministic = True | ||
|
||
# Device | ||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
# Create the model | ||
hand_seg_model = torch.hub.load( | ||
repo_or_dir='guglielmocamporese/hands-segmentation-pytorch', | ||
model='hand_segmentor', | ||
pretrained=True | ||
) | ||
hand_seg_model.eval() | ||
|
||
model_path = '../models/unet_aicityt4.pth' | ||
|
||
# Define model | ||
segmentation_model = build_unet() | ||
checkpoint = torch.load(model_path, map_location="cpu") | ||
segmentation_model.load_state_dict(checkpoint) | ||
# Send to GPU | ||
segmentation_model = segmentation_model.to(DEVICE) | ||
segmentation_model.eval() | ||
|
||
def infer_frame(frame, model, hand_seg_model, segmentation_model): | ||
img = Image.fromarray(frame).convert('RGB') | ||
img = transforms_image(img) | ||
segmented_image = crop_with_hand_entropy_seg(img, hand_seg_model, segmentation_model)["roi"] | ||
x,y,w,h = get_blobs(segmented_image) | ||
image_roi = frame[y:y+h, x:x+w] | ||
img_normalized = preprocess_image_classification(frame=image_roi[:,:,[2,1,0]]) | ||
img_normalized = img_normalized.unsqueeze_(0) | ||
img_normalized = img_normalized.to("cpu") | ||
with torch.no_grad(): | ||
model.eval() | ||
output =model(img_normalized) | ||
index = output.data.cpu().numpy().argmax() | ||
op_array = output.data.cpu().numpy() | ||
print(op_array[0][index]) | ||
return index+1 | ||
|
||
|
||
|
||
video_location = "../test-videos/" | ||
videos = ["testA_1.mp4"] | ||
images_list = [] | ||
ratios_list = [] | ||
colors_list = [] | ||
metrics_list = [] | ||
|
||
for video in videos: | ||
vidcap = cv2.VideoCapture(video_location+video) | ||
fps = vidcap.get(cv2.CAP_PROP_FPS) | ||
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) | ||
resized_image = [] | ||
|
||
for i in tqdm(range(frame_count)): | ||
try: | ||
success,image = vidcap.read() | ||
image = image[256:896, 512:1400] | ||
image = automatic_brightness_and_contrast(image) | ||
image = cv2.resize(image, dsize=(224, 224), interpolation=cv2.INTER_CUBIC) | ||
resized_image.append(image) | ||
except: | ||
pass | ||
ratios = [] | ||
colors = [] | ||
metrics = [] | ||
for img in resized_image: | ||
# ratio = get_ratio(img) | ||
color = image_colorfulness(img) | ||
# metric = color*color*ratio | ||
colors.append(color) | ||
# ratios.append(ratio) | ||
# metrics.append(metric) | ||
|
||
images_list.append(resized_image) | ||
# ratios_list.append(ratios) | ||
colors_list.append(colors) | ||
# metrics_list.append(metrics) | ||
# print(video) | ||
# plt.plot(metrics) | ||
# plt.show() | ||
# plt.plot(ratios) | ||
# plt.show() | ||
plt.plot(colors) | ||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
vit_base_patch32_224.pt filter=lfs diff=lfs merge=lfs -text | ||
unet_aicityt4.pth filter=lfs diff=lfs merge=lfs -text |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import shutil\n", | ||
"import os\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Fetch and Copy Dataset into folders" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get dirs\n", | ||
"datapath = \"../dataset/synthetic_images_with_tray_bg\"\n", | ||
"dirpath, dirs, filenames = next(os.walk(datapath))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for file in filenames:\n", | ||
" class_label = file.split(\"_\")[0]\n", | ||
" path_to_save = \"../dataset/train/\"+class_label\n", | ||
" if not os.path.exists(path_to_save):\n", | ||
" os.makedirs(path_to_save)\n", | ||
" shutil.copy(datapath+\"/\"+file, path_to_save)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Build Test Dataset Folder" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import shutil\n", | ||
"import os\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"def get_files_from_folder(path):\n", | ||
"\n", | ||
" files = os.listdir(path)\n", | ||
" return np.asarray(files)\n", | ||
"\n", | ||
"def move(path_to_data, path_to_test_data, train_ratio):\n", | ||
" # get dirs\n", | ||
" _, dirs, _ = next(os.walk(path_to_data))\n", | ||
"\n", | ||
" # calculates how many train data per class\n", | ||
" data_counter_per_class = np.zeros((len(dirs)))\n", | ||
" for i in range(len(dirs)):\n", | ||
" path = os.path.join(path_to_data, dirs[i])\n", | ||
" files = get_files_from_folder(path)\n", | ||
" data_counter_per_class[i] = len(files)\n", | ||
" test_counter = np.round(data_counter_per_class * (1 - train_ratio))\n", | ||
"\n", | ||
" # transfers files\n", | ||
" for i in range(len(dirs)):\n", | ||
" path_to_original = os.path.join(path_to_data, dirs[i])\n", | ||
" path_to_save = os.path.join(path_to_test_data, dirs[i])\n", | ||
"\n", | ||
" #creates dir\n", | ||
" if not os.path.exists(path_to_save):\n", | ||
" os.makedirs(path_to_save)\n", | ||
" files = get_files_from_folder(path_to_original)\n", | ||
" # moves data\n", | ||
" for j in range(int(test_counter[i])):\n", | ||
" dst = os.path.join(path_to_save, files[j])\n", | ||
" src = os.path.join(path_to_original, files[j])\n", | ||
" shutil.move(src, dst)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"move(\"../dataset/train\",\"../dataset/valid\",0.8)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "aa884020ed65da342b2c717bde878a8b9bf503b8ca14673f10767b13c4c8aa94" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.8.12 ('timm-env')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.12" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |