# inference_ocr.py
import io
import os
import argparse
import numpy as np
from PIL import Image
from datetime import datetime
import matplotlib.pyplot as plt
from utils.constants import FUNSD_EVAL_SET, ALPHABET_SET
from utils.common import dump_dict_to_json
# Inference backends
import boto3
import easyocr
import pytesseract
from dotenv import load_dotenv
from azure.core.credentials import AzureKeyCredential
from azure.ai.formrecognizer import FormRecognizerClient

def concatenate_images_vertically(image_files, image_dir, margin=10):
    # Crop range dict: maps crop index -> source image name, vertical range, and OCR text
    crop_metadata = {}
    # Load the images
    images = [Image.open(os.path.join(image_dir, i)) for i in image_files]
    # Get the dimensions of the images
    widths, heights = zip(*(i.size for i in images))
    # Get the total height, including a margin after each image
    total_height = sum(heights) + len(images) * margin
    # Create a new image with the width of the widest image and the total height
    concatenated_image = Image.new('RGB', (max(widths), total_height))
    # Paste the images top to bottom
    y_offset = 0
    for i, im in enumerate(images):
        w, h = im.size
        concatenated_image.paste(im, (0, y_offset))
        # Record which vertical slice of the concatenated image this crop occupies
        crop_metadata[i] = {
            'image': image_files[i],
            'range': (y_offset, y_offset + h),
            'text': ''
        }
        # Advance the offset past this image plus the margin
        y_offset += h + margin
    return concatenated_image, crop_metadata

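# Illustrative sketch (hypothetical filenames, not part of the pipeline):
# concatenating two 100px-tall crops 'a.png' and 'b.png' with margin=10
# yields crop_metadata of the form
#   {0: {'image': 'a.png', 'range': (0, 100), 'text': ''},
#    1: {'image': 'b.png', 'range': (110, 210), 'text': ''}}
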
def check_in_range(source_range, target_range, margin=10):
    # A source line falls inside a crop if it fits within the target range
    # expanded by half the concatenation margin on each side
    halved_margin = margin // 2
    return source_range[0] >= target_range[0] - halved_margin and source_range[1] <= target_range[1] + halved_margin

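# Example (hypothetical values): with margin=10 each target range is relaxed
# by 5px per side, so check_in_range((98, 152), (100, 150)) is True, while
# check_in_range((90, 152), (100, 150)) is False.
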
def pad_to_50(image):
    # If the (grayscale, float-valued) image is smaller than 50x50, pad it
    # with zeros on the bottom/right, then rescale to 8-bit
    h, w = image.shape
    if h < 50:
        image = np.pad(image, ((0, 50 - h), (0, 0)), mode='constant', constant_values=0)
    if w < 50:
        image = np.pad(image, ((0, 0), (0, 50 - w)), mode='constant', constant_values=0)
    return (image * 255).astype(np.uint8)

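# Minimal usage sketch (assumes a float grayscale array in [0, 1], as
# matplotlib returns for single-channel PNGs):
#   tiny = np.ones((30, 40), dtype=np.float32)
#   pad_to_50(tiny)  # -> shape (50, 50), dtype uint8, ones scaled to 255
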
def align_document_intelligence_results(results, crop_metadata, margin):
    # Index of the first crop that has not yet been matched
    metadata_index = 0
    # Iterate over the detected lines on the first page
    for line in results[0].lines:
        # Get min and max y coordinates of the line's bounding box
        min_y = np.min(np.array(line.bounding_box)[:, 1]).astype(int)
        max_y = np.max(np.array(line.bounding_box)[:, 1]).astype(int)
        # Iterate across the remaining (unmatched) crops
        for key, value in list(crop_metadata.items())[metadata_index:]:
            # If the line's y range falls within the crop's range, append its text
            if check_in_range((min_y, max_y), value['range'], margin=margin):
                value['text'] += line.text
                # Resume the search from this crop for the next line
                metadata_index = key
                break
    return crop_metadata

def document_intelligence_batched_inference(concatenated_image, concatenation_metadata, margin):
    # Serialise the concatenated image and call the Form Recognizer layout model
    image_stream = io.BytesIO()
    concatenated_image.save(image_stream, format='PNG')
    image_stream.seek(0)
    try:
        poller = form_recognizer_client.begin_recognize_content(image_stream)
        complete_document_result = poller.result()
        aligned_predictions = align_document_intelligence_results(complete_document_result, concatenation_metadata, margin)
        # Restructure results, keyed by the original crop filename
        predictions = {}
        for key, value in aligned_predictions.items():
            predictions[value['image']] = [
                {
                    'Text': value['text']
                }
            ]
    except Exception as e:
        # On failure, fall back to empty predictions for every crop
        print("Error: ", e)
        predictions = {}
        for key, value in concatenation_metadata.items():
            predictions[value['image']] = [
                {
                    'Text': value['text']
                }
            ]
    return predictions

def align_textract_results(results, crop_metadata, concatenated_image_dims, margin):
    # Extract detected LINE blocks from the response
    detected_text = [block for block in results['Blocks'] if block['BlockType'] in ['LINE']]
    metadata_index = 0
    detected_lines = {}
    # Group detected lines that share (roughly) the same vertical position
    for line in detected_text:
        y0 = int(line['Geometry']['BoundingBox']['Top'] * concatenated_image_dims[1])
        y1 = y0 + int(line['Geometry']['BoundingBox']['Height'] * concatenated_image_dims[1])
        # Check if y0 is within half a margin of an existing key
        key = next((k for k in detected_lines if abs(k - y0) <= margin // 2), None)
        if key is None:
            # If not, add a new entry to the dictionary keyed by y0
            detected_lines[y0] = {
                'text': line['Text'],
                'range': (y0, y1)
            }
        else:
            # Otherwise, append the new text to the existing line
            detected_lines[key]['text'] += ' ' + line['Text']
    # Match each grouped line back to the crop whose range contains it
    for y0, line in detected_lines.items():
        min_y = line['range'][0]
        max_y = line['range'][1]
        # Iterate across the remaining (unmatched) crops
        for key, value in list(crop_metadata.items())[metadata_index:]:
            # If the line's y range falls within the crop's range, append its text
            if check_in_range((min_y, max_y), value['range'], margin=margin):
                value['text'] += line['text']
                # Resume the search from this crop for the next line
                metadata_index = key
                break
    return crop_metadata

def textract_batched_inference(concatenated_image, concatenation_metadata, margin):
    # Convert the concatenated image to bytes
    image_bytes_io = io.BytesIO()
    concatenated_image.save(image_bytes_io, format='PNG')
    image_bytes = image_bytes_io.getvalue()
    try:
        # Predict
        response = textract_client.detect_document_text(Document={'Bytes': image_bytes})
        aligned_predictions = align_textract_results(response, concatenation_metadata,
                                                     concatenated_image_dims=concatenated_image.size,
                                                     margin=margin)
        # Restructure results, keyed by the original crop filename
        predictions = {}
        for key, value in aligned_predictions.items():
            predictions[value['image']] = [
                {
                    'Text': value['text']
                }
            ]
    except Exception as e:
        # On failure, fall back to empty predictions for every crop
        print("Error: ", e)
        predictions = {}
        for key, value in concatenation_metadata.items():
            predictions[value['image']] = [
                {
                    'Text': value['text']
                }
            ]
    return predictions

def perform_batched_ocr_on_image(input_dir, dataset, model, evaluation_set, alphabet_set, output_dir="../outputs", subset=None, margin=10):
    # Avoid a mutable default argument
    subset = subset or []
    # Set metadata
    timenow = str(datetime.now()).split('.')[0].replace('-', '').replace(' ', '_').replace(':', '')
    output_filename = f"{output_dir}/{model}_inference_{timenow}_{dataset}_dataset_crops.json"
    metadata = {
        'input_dir': input_dir,
        'model': model,
        'dataset': dataset,
        'datetime': timenow
    }
    # Initialise predictions dict
    outputs = {
        'metadata': metadata,
        'predictions': {}
    }
    # Load all image names in the directory, optionally restricted to a subset
    image_files = os.listdir(input_dir)
    if subset:
        image_files = [image for image in image_files if image in subset]
    # Iterate across the evaluation set
    for evaluation_image_name in evaluation_set:
        for alphabet in alphabet_set:
            capture_name = f"{evaluation_image_name}_{alphabet}"
            # Skip captures explicitly listed in subset
            if capture_name in subset:
                print(f"Skipping {capture_name}")
                continue
            # Select the crops belonging to this capture, in crop order
            images = [image for image in image_files if capture_name in image]
            images.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
            # Concatenate vertically and get metadata
            concatenated_image, concatenation_metadata = concatenate_images_vertically(images, input_dir, margin=margin)
            print(f"{str(datetime.now()).split('.')[0].replace('-', '').replace(' ', '_').replace(':', '')}\tSize for image {capture_name}:\t{concatenated_image.size}")
            if model == 'document_intelligence':
                predictions = document_intelligence_batched_inference(concatenated_image, concatenation_metadata, margin=margin)
            else:
                predictions = textract_batched_inference(concatenated_image, concatenation_metadata, margin=margin)
            outputs['predictions'].update(predictions)
    # Output to the output directory as JSON
    dump_dict_to_json(outputs, output_filename)
    print("Results stored at ", output_filename)
    return outputs

def tesseract_inference(image_path):
    # Read the image
    img = Image.open(image_path)
    # Infer text with Tesseract
    text = pytesseract.image_to_string(img).strip()
    return [{"Text": text}]

def easyocr_inference(image_path):
    # Infer and join text detections with EasyOCR
    detections = easyocr_reader.readtext(image_path)
    text = " ".join([detection[1] for detection in detections])
    return [{"Text": text}]

def azure_document_intelligence_inference(image_path):
    # Read the image, pad small crops to at least 50x50, and serialise to PNG
    image = pad_to_50(plt.imread(image_path))
    image = Image.fromarray(image)
    image_stream = io.BytesIO()
    image.save(image_stream, format='PNG')
    image_stream.seek(0)
    # Call Azure Document Intelligence to analyse the document
    poller = form_recognizer_client.begin_recognize_content(image_stream)
    result = poller.result()
    # Extract detected words from the result
    detected_text = []
    for page in result:
        for line in page.lines:
            detected_text += [word.text for word in line.words]
    # Return the detected text
    return [{"Text": " ".join(detected_text)}]

def aws_textract_inference(image_path):
    # Read image file as binary data
    with open(image_path, 'rb') as file:
        image_bytes = file.read()
    # Call the Textract API to detect text in the image
    response = textract_client.detect_document_text(Document={'Bytes': image_bytes})
    # Keep only blocks of type 'LINE'
    detected_text = []
    for block in response['Blocks']:
        if block['BlockType'] in ['LINE']:
            detected_text.append(block)
    # Return the detected line blocks
    return detected_text

def model_inference(image_path, model):
    # Dispatch to the requested OCR backend
    if model == 'tesseract':
        return tesseract_inference(image_path)
    elif model == 'easyocr':
        return easyocr_inference(image_path)
    elif model == 'document_intelligence':
        return azure_document_intelligence_inference(image_path)
    elif model == 'textract':
        return aws_textract_inference(image_path)
    else:
        raise ValueError(f"Model {model} not supported")

# Gather all images in a directory, run each through the chosen model, and write the resulting text to the output folder
def perform_ocr_on_image(input_dir, dataset, model, output_dir="../outputs", subset=None):
    # Set metadata
    timenow = str(datetime.now()).split('.')[0].replace('-', '').replace(' ', '_').replace(':', '')
    output_filename = f"{output_dir}/{model}_inference_{timenow}_{dataset}_dataset_crops.json"
    metadata = {
        'input_dir': input_dir,
        'model': model,
        'dataset': dataset,
        'datetime': timenow
    }
    # Initialise predictions dict
    outputs = {
        'metadata': metadata,
        'predictions': {}
    }
    # Load all image names in the directory, optionally restricted to a subset
    image_files = os.listdir(input_dir)
    if subset:
        image_files = [image for image in image_files if image in subset]
    # Iterate across image names
    for image in image_files:
        # Call model inference
        try:
            detected_text = model_inference(f"{input_dir}/{image}", model)
        except Exception as e:
            print(f"Text detection failed on {input_dir}/{image}\n{e}")
            detected_text = [{
                "Text": "",
                "Status": f"Text not found or errored out: {e}"
            }]
        # Update the predictions dictionary
        outputs['predictions'][image] = detected_text
    # Output to the output directory as JSON
    dump_dict_to_json(outputs, output_filename)
    print("Results stored at ", output_filename)
    return outputs

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Perform OCR on images')
    parser.add_argument('--input_dir', type=str, help='Input directory containing images')
    parser.add_argument('--dataset', type=str, help='Dataset name')
    parser.add_argument('--model', type=str, help='Model name',
                        choices=['tesseract', 'easyocr', 'document_intelligence', 'textract'])
    parser.add_argument('--method', type=str,
                        help='Inference method - individual or batched (batched supports document_intelligence and textract only)',
                        choices=['individual', 'batched'])
    args = parser.parse_args()
    # Azure
    if args.model == "document_intelligence":
        load_dotenv()
        endpoint = os.getenv("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
        key = os.getenv("AZURE_DOCUMENT_INTELLIGENCE_KEY")
        form_recognizer_client = FormRecognizerClient(endpoint=endpoint, credential=AzureKeyCredential(key))
    # AWS
    elif args.model == "textract":
        textract_client = boto3.client('textract')
    # EasyOCR
    elif args.model == "easyocr":
        easyocr_reader = easyocr.Reader(['en'])
    # Perform OCR
    if args.method == 'individual':
        results = perform_ocr_on_image(input_dir=args.input_dir, dataset=args.dataset, model=args.model)
    else:
        results = perform_batched_ocr_on_image(input_dir=args.input_dir,
                                               dataset=args.dataset,
                                               model=args.model,
                                               evaluation_set=FUNSD_EVAL_SET,
                                               alphabet_set=ALPHABET_SET)
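
# Example invocations (hypothetical paths; the relevant credentials/clients
# must be configured for the chosen backend):
#   python inference_ocr.py --input_dir ../data/crops --dataset funsd --model textract --method batched
#   python inference_ocr.py --input_dir ../data/crops --dataset funsd --model tesseract --method individual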