-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
93 lines (78 loc) · 2.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import random
from functools import partial
import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from lightning_fabric import seed_everything
from pydantic import BaseModel
from typing import List
from model.ImageTransforms import adjust_dpi, InfImageTransforms
from model.dataset import IAMDataset
from model.model import PadPool
import threading
import queue
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
from torchvision import transforms
import cv2 as cv
from InferencePipeline import InferencePipeline
from ModelLoader import ModelLoader
from model.utils import pickle_save
# ds = IAMDataset(root="/Users/tefannastasa/BachelorsWorkspace/BachModels/BachModels/data/raw/IAM", label_enc=None, parse_method="form" ,split="test")
# pickle_save(ds.label_enc, "./label_enc")
app = FastAPI()
model = ModelLoader()
pipeline = InferencePipeline(model)
# Only the evaluation-time (`test_trnsf`) transform of InfImageTransforms is used here.
transform = InfImageTransforms().test_trnsf

# Place tensors on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Cuda is available." if device.type == "cuda" else "Cuda is not available.")
class UrlList(BaseModel):
    """Request body accepted by the ``/predict`` endpoint.

    ``urls`` lists the image addresses to run inference on; the endpoint
    processes them in list order and joins the results in that same order.
    """

    urls: List[str]  # one image URL per entry
def download_image(url, timeout=10.0):
    """Download an image over HTTP and turn it into a model-ready tensor.

    The payload is decoded as single-channel grayscale, run through the
    module-level ``transform`` and moved to the module-level ``device``.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait on the remote server before giving up.
            ``requests`` waits indefinitely when no timeout is given, which
            would pin a worker forever on an unresponsive host.

    Returns:
        The transformed image as a tensor on ``device`` with two leading
        singleton dimensions prepended (presumably batch and channel --
        TODO confirm against the model's expected input layout).

    Raises:
        HTTPException: status 400 when the request fails, the server does not
            answer 200, or the payload cannot be decoded as an image.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        # Connection errors, invalid URLs, timeouts: a client-side problem, not a 500.
        raise HTTPException(status_code=400, detail=f"Failed to download image from {url}") from err
    if response.status_code != 200:
        raise HTTPException(status_code=400, detail=f"Failed to download image from {url}")

    image_array = np.frombuffer(response.content, np.uint8)
    # cv.imdecode asserts on an empty buffer and returns None on undecodable
    # bytes; reject both explicitly instead of crashing further down.
    image = cv.imdecode(image_array, cv.IMREAD_GRAYSCALE) if image_array.size else None
    if image is None:
        raise HTTPException(status_code=400, detail=f"Could not decode image from {url}")

    # NOTE(review): debug dump of the most recent image; every request
    # overwrites the same file, so concurrent requests race on it.
    cv.imwrite("./image.png", image)

    image = transform(image=image)["image"]
    image = torch.tensor(image).to(device)
    # Prepend two singleton dimensions (not just one batch dimension).
    return image.unsqueeze(0).unsqueeze(0)
@app.post("/predict")
async def predict(url_list: UrlList, background_tasks: BackgroundTasks):
    """Run the inference pipeline on every image URL in the request.

    Each URL is downloaded with ``download_image``, submitted to the shared
    ``pipeline`` and its result awaited, one URL at a time. Each result is
    joined with ``""`` (presumably a sequence of characters/tokens -- TODO
    confirm against InferencePipeline), and the per-URL strings are joined
    with single spaces.

    ``background_tasks`` is currently unused but kept so the route signature
    stays unchanged.

    Returns:
        ``{"prediction": <space-joined results>}``.

    Raises:
        HTTPException: propagated from ``download_image`` (status 400) when an
            image cannot be fetched or decoded.
    """
    import asyncio  # local import: only this endpoint needs the running loop

    loop = asyncio.get_running_loop()
    results = []
    print("Prediction request received!")
    for url in url_list.urls:
        # Both the HTTP download and result_queue.get() block; running them in
        # the default executor keeps the event loop free to serve other requests.
        input_image = await loop.run_in_executor(None, download_image, url)
        result_queue = pipeline.add_task(input_image)
        result = await loop.run_in_executor(None, result_queue.get)
        results.append("".join(result))
    prediction = " ".join(results)
    print(prediction)
    return {"prediction": prediction}
if __name__ == "__main__":
    import uvicorn

    # Seed the RNGs through Lightning's helper so runs are reproducible.
    seed_everything(5234)

    # Bind address and port can be overridden from the environment.
    host = os.environ.get("SERVER_ADDRESS", "0.0.0.0")
    port = int(os.environ.get("SERVER_PORT", "27018"))
    uvicorn.run(app, host=host, port=port)