-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_dataset.py
108 lines (77 loc) · 3.41 KB
/
extract_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from PIL import Image
import json
import numpy as np
import argparse
from tqdm import tqdm
def load_funsd(data_path, split='train'):
    """Load FUNSD form elements as cropped images paired with their text.

    Parameters
    ----------
    data_path : str
        Root of the FUNSD dataset (must contain 'training_data' and/or
        'testing_data' subdirectories).
    split : str
        Either 'train' or 'test'; selects which subdirectory to read.

    Returns
    -------
    list[dict]
        One dict per form element with keys:
        'image'    -- numpy array cropped to the element's bounding box,
        'text'     -- the element's transcription string,
        'filename' -- unique name '<split>_<document>_<index>'.

    Raises
    ------
    Exception
        If `split` is not 'train'/'test' or the split directory is missing.
    """
    if split == 'train':
        data_path = os.path.join(data_path, 'training_data')
    elif split == 'test':
        data_path = os.path.join(data_path, 'testing_data')
    else:
        raise Exception('invalid split')
    if not os.path.exists(data_path):
        raise Exception('dataset not found')
    instances = []
    annotations_path = os.path.join(data_path, 'annotations')
    images_path = os.path.join(data_path, 'images')
    for annotation_file in os.listdir(annotations_path):
        # FUNSD pairs each <name>.json annotation with <name>.png image.
        image_file = os.path.join(images_path, annotation_file.replace('.json', '.png'))
        with Image.open(image_file) as image:
            image = np.array(image)
        try:
            # Context manager closes the handle; the original leaked it
            # via json.load(open(...)).
            with open(os.path.join(annotations_path, annotation_file), encoding='utf-8') as f:
                annotation = json.load(f)
            for idx, element in enumerate(annotation['form']):
                # 'box' is [x0, y0, x1, y1]; numpy indexing is rows (y) first.
                bbox = element['box']
                subimage = image[bbox[1]:bbox[3], bbox[0]:bbox[2]]
                instances.append({
                    'image': subimage,
                    'text': element['text'],
                    'filename': split + '_' + annotation_file.replace('.json', '') + '_' + str(idx),
                })
        except Exception as e:
            # Best-effort: report and skip malformed annotation files
            # instead of aborting the whole extraction.
            print('Error reading file:', annotation_file, e)
    return instances
def save_extracted(instances, data_path):
    """Write extracted instances to disk as PNG images plus .txt transcripts.

    Creates `data_path` with 'images' and 'annotations' subdirectories and
    writes one <filename>.png / <filename>.txt pair per instance.

    Parameters
    ----------
    instances : list[dict]
        Items produced by `load_funsd` ('image', 'text', 'filename' keys).
    data_path : str
        Output directory; must NOT already exist.

    Raises
    ------
    Exception
        If `data_path` already exists (refuses to overwrite prior output).
    """
    annotations_path = os.path.join(data_path, 'annotations')
    images_path = os.path.join(data_path, 'images')
    if os.path.exists(data_path):
        raise Exception('Output directory: ' + data_path + ' already exists')
    # makedirs also creates missing parents of data_path, unlike os.mkdir.
    os.makedirs(images_path)
    os.makedirs(annotations_path)
    for instance in tqdm(instances):
        annotation_file = os.path.join(annotations_path, instance['filename'] + '.txt')
        image_file = os.path.join(images_path, instance['filename'] + '.png')
        Image.fromarray(instance['image']).save(image_file)
        # Context manager guarantees the handle is closed even if write
        # fails; the original used bare open()/close() with no try/finally.
        with open(annotation_file, 'w', encoding='utf-8') as f:
            f.write(instance['text'])
# Example usage:
#   python extract_dataset.py --input_path ../../FUNSD/ --output_path ./OCRDataset/
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process FUNSD dataset to extract the images and annotations')
    parser.add_argument('--input_path', type=str, required=True, help='Path of FUNSD')
    parser.add_argument('--output_path', type=str, required=True, help='Path of the output')
    args = parser.parse_args()
    input_path = args.input_path
    output_path = args.output_path
    # Fail fast: verify the output directory is usable BEFORE the expensive
    # image-decoding pass (the original loaded both splits first, then
    # failed if the directory already existed).
    if os.path.exists(output_path):
        raise Exception('Output directory: ' + output_path + ' already exists')
    output_path_train = os.path.join(output_path, 'training_data')
    output_path_test = os.path.join(output_path, 'testing_data')
    instances_train = load_funsd(input_path, split='train')
    instances_test = load_funsd(input_path, split='test')
    os.mkdir(output_path)
    save_extracted(instances_train, output_path_train)
    save_extracted(instances_test, output_path_test)