-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path3dcnn.py
110 lines (96 loc) · 3.15 KB
/
3dcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import cv2
from torchvision import transforms, datasets
from albumentations import (
HorizontalFlip,
VerticalFlip,
ShiftScaleRotate,
CLAHE,
RandomRotate90,
Transpose,
ShiftScaleRotate,
HueSaturationValue,
GaussNoise,
Sharpen,
Emboss,
RandomBrightnessContrast,
OneOf,
Compose,
)
import numpy as np
from PIL import Image
from decord import VideoReader, cpu
from tqdm import tqdm
from torchvision.models.video import mc3_18
# Compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# MC3-18 video ResNet from torchvision, moved to `device` and put in
# inference mode (.to() and .eval() both return the same module instance).
# NOTE(review): weights=None means no pretrained weights are loaded, so the
# network is randomly initialised and its predictions carry no meaning until
# trained weights are loaded (e.g. torchvision's MC3_18_Weights) — TODO confirm
# whether pretrained weights were intended.
mc3_model = mc3_18(weights=None)
mc3_model = mc3_model.to(device)
mc3_model = mc3_model.eval()
# Data augmentation setup
def strong_aug(p=0.5):
    """Build the albumentations pipeline used to perturb a single frame.

    The composed pipeline as a whole is applied with probability ``p``; each
    transform inside it additionally has its own probability.

    Args:
        p: probability that the whole composed pipeline is applied.

    Returns:
        An albumentations ``Compose`` object.
    """
    # At most one noise transform, applied 20% of the time.
    noise_stage = OneOf([GaussNoise()], p=0.2)

    # At most one contrast / texture transform, applied 20% of the time.
    appearance_stage = OneOf(
        [
            CLAHE(clip_limit=2),
            Sharpen(),
            Emboss(),
            RandomBrightnessContrast(),
        ],
        p=0.2,
    )

    stages = [
        # Geometric transforms.
        RandomRotate90(p=0.2),
        Transpose(p=0.2),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        noise_stage,
        ShiftScaleRotate(p=0.2),
        appearance_stage,
        # Colour jitter.
        HueSaturationValue(p=0.2),
    ]
    return Compose(stages, p=p)
# Data loading and normalization
def load_and_preprocess_frames(video_file, num_frames=15):
    """Decode, augment and normalise up to ``num_frames`` frames of a video.

    Frames are sampled with a fixed stride of ``len(video) // num_frames``
    (at least 1), starting at frame 0, and the first ``num_frames`` sampled
    indices are kept. Videos shorter than ``num_frames`` therefore yield
    fewer frames.

    Args:
        video_file: path to a video file readable by decord.
        num_frames: maximum number of frames to sample; must be >= 1.

    Returns:
        A list of float tensors, one per sampled frame, in temporal order.
        Each is (C, H, W) as produced by ``transforms.ToTensor`` and is
        normalised with ImageNet mean/std.

    Raises:
        ValueError: if ``num_frames`` < 1 or the video contains no frames.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    vr = VideoReader(video_file, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"no frames could be read from {video_file!r}")

    step_size = max(1, total_frames // num_frames)
    indices = list(range(0, total_frames, step_size))[:num_frames]
    frames = vr.get_batch(indices).asnumpy()

    # albumentations pipelines take the image as the `image=` keyword argument
    # (a numpy HxWxC array) and return a dict. Calling them positionally,
    # or with a PIL image, raises.
    aug = strong_aug(p=0.9)
    # NOTE(review): the pipeline is re-sampled for every frame, so flips and
    # rotations differ between frames of the same clip. RandomRotate90 and
    # Transpose swap height and width on non-square frames, which can leave
    # frames with different shapes and make a later torch.stack fail.
    # Applying identical parameters to the whole clip (e.g. albumentations
    # ReplayCompose) is probably what is wanted — TODO confirm.
    augmented_frames = [aug(image=frame)["image"] for frame in frames]

    # Normalize frames and convert to tensors.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # ToTensor goes through torch.from_numpy, so hand it a contiguous array.
    return [transform(np.ascontiguousarray(frame)) for frame in augmented_frames]
# Perform spatiotemporal analysis
def spatiotemporal_analysis(frames_batch, model=None):
    """Classify a clip of frames with a 3D CNN and return the top class index.

    Args:
        frames_batch: non-empty sequence of per-frame tensors, each
            (C, H, W) and all the same shape, in temporal order (as returned
            by ``load_and_preprocess_frames``).
        model: video classifier that takes a (batch, C, T, H, W) tensor and
            returns (batch, num_classes) logits. Defaults to the module-level
            ``mc3_model``.

    Returns:
        int: index of the highest-probability class for the clip.

    Raises:
        ValueError: if ``frames_batch`` is empty.
    """
    if model is None:
        model = mc3_model
    if len(frames_batch) == 0:
        raise ValueError("frames_batch must contain at least one frame")

    # torchvision video ResNets (mc3_18 included) expect
    # (batch, channels, time, height, width). Stacking on dim=1 turns T tensors
    # of (C, H, W) into (C, T, H, W); unsqueeze(0) adds a batch of one clip.
    clip = torch.stack(list(frames_batch), dim=1).unsqueeze(0)

    # Send the input to whichever device holds the model's weights.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    clip = clip.to(target_device)

    with torch.no_grad():
        logits = model(clip)
        # Class probabilities over however many outputs the model has
        # (torchvision's mc3_18 defaults to 400 classes).
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Softmax is monotonic, so this is also the argmax of the logits.
        predictions = torch.argmax(probs, dim=1)

    # Print logits and probabilities
    print(logits)
    print(probs)
    # A single clip was classified, so there is exactly one prediction.
    return predictions.item()
# Example usage: run the whole pipeline on one video, but only when this file
# is executed as a script, so importing the module has no side effects.
if __name__ == "__main__":
    # Placeholder path: replace with a real video before running.
    video_file = 'path/to/your/video.mp4'
    num_frames = 15
    frames_batch = load_and_preprocess_frames(video_file, num_frames)
    prediction = spatiotemporal_analysis(frames_batch)
    print(f"Final Prediction: {prediction}")