-
Notifications
You must be signed in to change notification settings - Fork 1
/
loader.py
47 lines (37 loc) · 1.14 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import torch
import numpy as np
def read_video(filename, size=(100, 100)):
    """Decode every frame of a video file and resize each one.

    Args:
        filename (string): Path to the video file.
        size (tuple): Target ``(width, height)`` passed as ``dsize`` to
            ``cv2.resize``. Defaults to ``(100, 100)``.

    Returns:
        numpy.ndarray: The resized frames stacked along a new first axis, in
        the order they were read. Empty if no frame could be read.
    """
    video = cv2.VideoCapture(filename)
    vid = []
    try:
        while True:
            ret, img = video.read()
            # read() reports ret=False once no further frame can be grabbed.
            if not ret:
                break
            vid.append(cv2.resize(img, size))
    finally:
        # Release the capture handle explicitly instead of relying on GC.
        video.release()
    return np.array(vid)
class Firedataset(Dataset):
    """Fire dataset.

    Each item pairs one file from ``video_dir`` with the entry of the label
    file at the same position.
    """

    def __init__(self, video_dir, label_dir):
        """
        Args:
            video_dir (string): Directory with all the videos.
            label_dir (string): Path to the label data.
        """
        self.video_dir = video_dir
        # NOTE(review): os.listdir returns entries in arbitrary order, yet
        # labels are matched to files purely by index. Confirm the label
        # file's order matches this listing (e.g. sort both) — TODO verify.
        self.filename = os.listdir(video_dir)
        # np.loadtxt yields a 0-d array for a single-value file, which cannot
        # be indexed; atleast_1d leaves 1-D/2-D results unchanged.
        self.labels = np.atleast_1d(np.loadtxt(label_dir))

    def __len__(self):
        # Use the listing captured at construction so len() always agrees with
        # the indices __getitem__ accepts (and avoids a directory scan per call).
        return len(self.filename)

    def __getitem__(self, idx):
        """Return ``{'label': ..., 'video': ...}`` for item ``idx``.

        ``video`` is ``read_video`` applied to the idx-th listed file, and
        ``label`` is ``self.labels[idx]`` repeated once per frame
        (``video.shape[0]`` times).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        video_name = os.path.join(self.video_dir, self.filename[idx])
        video = read_video(video_name)
        label = self.labels[idx].repeat(video.shape[0])
        sample = {'label': label, 'video': video}
        return sample