-
Notifications
You must be signed in to change notification settings - Fork 0
/
table_generation_deplot.py
147 lines (114 loc) · 5.04 KB
/
table_generation_deplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import numpy as np
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
import json
from PIL import Image
import os
import torch
import cv2
import requests
from io import BytesIO
# JSON taskset that this script loads at start-up. The path is relative, so it
# is resolved against the current working directory.
TASKSET_PATH = "claim_explanation_generation_pre_tasksets.json"
# Directory into which one "<image basename>.txt" DePlot table is written per
# chart image; an existing file there makes the script skip that image.
DIR_DEPLOT_GEN_TABLES = "/scratch/users/k20116188/chart-fact-checking/deplot-tables"
def sharpen_image(img, kernel_size=3, sharpen_strength=3):
    """Sharpen an image by unsharp masking and return it as a PIL image.

    The image is blurred with a square Gaussian kernel and the result is
    combined with the original as
    ``(1 + sharpen_strength) * img - sharpen_strength * blurred``.

    Args:
        img: Image to sharpen; anything ``np.array`` converts into an array
            OpenCV can filter (for example a PIL image).
        kernel_size: Side length of the square Gaussian blur kernel. Must be
            a positive odd integer, as OpenCV requires. Defaults to 3, the
            value that used to be hard-coded.
        sharpen_strength: Weight of the (original - blurred) detail that is
            added back. Defaults to 3, the value that used to be hard-coded.

    Returns:
        The sharpened image as a ``PIL.Image.Image``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd number.
    """
    # Reject bad kernel sizes up front with a readable message instead of an
    # opaque OpenCV assertion failure.
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size!r}"
        )

    pixels = np.array(img)

    # Build the 2-D Gaussian kernel as the outer product of the 1-D kernel
    # with itself. Passing sigma=0 lets OpenCV derive sigma from the size.
    gauss_1d = cv2.getGaussianKernel(kernel_size, 0)
    gauss_2d = np.outer(gauss_1d, gauss_1d.transpose())

    # Blur the image (ddepth=-1 keeps the input depth).
    blurred = cv2.filter2D(pixels, -1, gauss_2d)

    # Weighted sum: boost the original and subtract the blurred copy.
    sharpened = cv2.addWeighted(
        pixels, 1 + sharpen_strength, blurred, -sharpen_strength, 0
    )
    return Image.fromarray(sharpened)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the DePlot processor and model from the same pretrained checkpoint.
_DEPLOT_CHECKPOINT = "google/deplot"
processor = AutoProcessor.from_pretrained(_DEPLOT_CHECKPOINT)
# The processor is later called with images only (no text prompt), so the
# image processor's VQA mode is switched off.
processor.image_processor.is_vqa = False

model = Pix2StructForConditionalGeneration.from_pretrained(_DEPLOT_CHECKPOINT)
model.to(device)
# Load the full taskset (a JSON list of examples) into memory.
with open(TASKSET_PATH, "r") as f:
    data = json.load(f)
print(f"Length of loaded dataset is: {len(data)} entries.")

# TODO: ask Nikesh why this conversion is necessary.
# new_data = []
# for example in data:
#     try:
#         imgname = os.path.basename(example["chart_img"])
#         Image.open(f"ChartFC/{imgname}").convert('RGB')
#         new_data.append(example)
#     except Exception:
#         pass

# A NumPy array allows selecting examples with an index array below.
data = np.array(data)

# Fixed seed so the shuffled split is reproducible across runs.
np.random.seed(42)
indices = np.random.permutation(len(data))

# 80% train, 10% validation; the test split takes whatever remains, so
# rounding never drops an example.
num_train = int(0.8 * len(data))
num_val = int(0.1 * len(data))

train_indices = indices[:num_train]
val_indices = indices[num_train:num_train + num_val]
test_indices = indices[num_train + num_val:]

train_data = data[train_indices]
val_data = data[val_indices]
test_data = data[test_indices]

# The original called len() on this message, which silently discarded it;
# report the training-set size instead.
print(f"Training data length: {len(train_data)}")
def _load_chart_names(json_path):
    """Return the set of extension-less ``file_name`` values in *json_path*.

    Only element ``[0]` of the parsed JSON document is read; every entry in
    it must provide a ``"file_name"`` key.
    """
    with open(json_path, "r") as handle:
        entries = json.load(handle)[0]
    return {os.path.splitext(entry["file_name"])[0] for entry in entries}


# One set of image file stems per chart type.
bar_horizontal = _load_chart_names("barchart_horizontal.json")
bar_vertical = _load_chart_names("barchart_vertical.json")
line_chart = _load_chart_names("line_chart.json")
pie_chart = _load_chart_names("pie_chart.json")
# Chart-type label -> set of image file stems listed under that type.
chart_type_sets = {
    "bar_horizontal": bar_horizontal,
    "bar_vertical": bar_vertical,
    "line_chart": line_chart,
    "pie_chart": pie_chart,
}

# Training examples grouped by their assigned chart type.
category = {label: [] for label in (*chart_type_sets, "mixed")}

for item in train_data:
    # Match on the image file name without directory or extension.
    key = os.path.splitext(os.path.basename(item["chart_img"]))[0]
    matches = [label for label, names in chart_type_sets.items() if key in names]

    # A chart gets a specific type only when exactly one list claims it;
    # charts found in several lists, or in none, are labelled "mixed".
    item["chart_type"] = matches[0] if len(matches) == 1 else "mixed"
    category[item["chart_type"]].append(item)
print(f"len(data): {len(data)}")

# Seconds to wait for the image server before giving up on one chart.
# Without a timeout, a single stalled connection would hang the whole run.
REQUEST_TIMEOUT_SECONDS = 30

# Create the output directory up front; otherwise the first write below
# raises FileNotFoundError when the directory does not exist yet.
os.makedirs(DIR_DEPLOT_GEN_TABLES, exist_ok=True)

# Generate one DePlot table per chart image, for the whole dataset.
for item in data:
    path_table = os.path.join(DIR_DEPLOT_GEN_TABLES,
                              os.path.basename(item["chart_img"]) + ".txt")
    # Tables from earlier runs are kept, which makes the script resumable.
    if os.path.isfile(path_table):
        print(f"File {path_table} already exists.")
        continue

    # Download the chart and run DePlot on it. Any failure is reported and
    # the chart is skipped, so one bad image cannot abort the run.
    try:
        response = requests.get(item["chart_img"], timeout=REQUEST_TIMEOUT_SECONDS)
        # Fail on HTTP errors instead of decoding an error page as an image.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')

        normal_inputs = processor(images=img, return_tensors="pt")
        normal_generated_ids = model.generate(
            flattened_patches=normal_inputs["flattened_patches"].to(device),
            attention_mask=normal_inputs["attention_mask"].to(device),
            max_new_tokens=512)
        # The "<0x0A>" token in the decoded text stands for a newline.
        normal_predicted_answer = processor.tokenizer.batch_decode(
            normal_generated_ids,
            skip_special_tokens=True)[0].replace("<0x0A>", "\n")
    except Exception as e:
        print(f"Error for file {item['chart_img']}: {e}")
        continue

    # Save the generated table next to the other DePlot outputs.
    with open(path_table, "w", encoding="utf-8") as f:
        f.write(normal_predicted_answer)