Commit

Merge pull request #45 from arena-ai/44-update_chat-completion
DRAFT: methods that create the message from an image or from text
ngrislain authored Dec 11, 2024
2 parents 5179fe2 + c1d91e1 commit a8bf75b
Showing 16 changed files with 770 additions and 155 deletions.
@@ -0,0 +1,29 @@
"""Add column as_process to dde
Revision ID: a8b878a502f1
Revises: 5b09eca9fc4d
Create Date: 2024-12-05 10:00:45.060030
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = 'a8b878a502f1'
down_revision = '5b09eca9fc4d'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('documentdataextractor', sa.Column('process_as', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('documentdataextractor', 'process_as')
# ### end Alembic commands ###
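
For context, here is a minimal sketch (hypothetical, not part of this commit) of how the column added above could map onto the SQLModel side; the actual DocumentDataExtractor model is defined in app.models and may differ:

# Hypothetical sketch only — the real model lives in app.models.
from sqlmodel import Field, SQLModel

class DocumentDataExtractorSketch(SQLModel):
    # Nullable string column added by revision a8b878a502f1; expected values
    # are "text" or "image", and it drives how an uploaded PDF is processed.
    process_as: str | None = Field(default=None, nullable=True)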
144 changes: 42 additions & 102 deletions backend/app/api/routes/dde.py
@@ -1,24 +1,22 @@
from typing import Any, Literal
from typing import Any
from fastapi import APIRouter, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from sqlmodel import func, select
from sqlalchemy.exc import IntegrityError
import json
import io
import re
from pydantic import create_model, ValidationError
from pydantic import ValidationError
from app.api.deps import CurrentUser, SessionDep
from app.services import crud
from app.lm.models import (
ChatCompletionRequest,
Message as ChatCompletionMessage,
Message,
TokenLogprob,
)
from app.services.object_store import documents
from app.services.pdf_reader import pdf_reader
from app.lm.handlers import ArenaHandler
from app.ops import tup
from app.ops.documents import as_text
from app.ops.schema_converter import create_pydantic_model
from app.models import (
Message,
DocumentDataExtractorCreate,
@@ -32,10 +30,13 @@
DocumentDataExampleOut,
)
from openai.lib._pydantic import to_strict_json_schema
from app.handlers.prompt_for_image import full_prompt_from_image
from app.handlers.prompt_for_text import full_prompt_from_text

from app.models import ContentType

router = APIRouter()


@router.get("/", response_model=DocumentDataExtractorsOut)
def read_document_data_extractors(
session: SessionDep,
@@ -117,18 +118,21 @@ def create_document_data_extractor(
"""
try:
create_pydantic_model(document_data_extractor_in.response_template)
except KeyError:
except TypeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Incorrect type in response template: {str(e)}",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="received incorrect response template",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Unexpected error: {str(e)}",
)
document_data_extractor = DocumentDataExtractor.model_validate(
document_data_extractor_in,
update={
"owner_id": current_user.id,
"response_template": json.dumps(
document_data_extractor_in.response_template
),
"response_template": document_data_extractor_in.response_template
},
)
try:
@@ -169,12 +173,17 @@ def update_document_data_extractor(
if pdyantic_dict is not None:
try:
create_pydantic_model(pdyantic_dict)
except KeyError:
except TypeError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="received incorrect response template",
detail=f"Incorrect type in response template: {str(e)}",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Unexpected error: {str(e)}",
)
update_dict["response_template"] = json.dumps(pdyantic_dict)
update_dict["response_template"] = pdyantic_dict
document_data_extractor.sqlmodel_update(update_dict)
session.add(document_data_extractor)
session.commit()
@@ -262,9 +271,7 @@ def create_document_data_example(
detail="Not enough permissions",
)
# verify the example matches the template of the document data extractor
pyd_model = create_pydantic_model(
json.loads(document_data_extractor.response_template)
)
pyd_model = create_pydantic_model(document_data_extractor.response_template)
try:
pyd_model.model_validate(document_data_example_in.data)
except ValidationError:
@@ -426,52 +433,25 @@ async def extract_from_file(
status_code=status.HTTP_404_NOT_FOUND,
detail="DocumentDataExtractor has no owner",
)
# Build examples
examples = tup(
*(
tup(
as_text(
document_data_extractor.owner,
example.document_id,
example.start_page,
example.end_page,
),
example.data,
)
for example in document_data_extractor.document_data_examples
)
)
# Pull data from the file
if upload.content_type != "application/pdf":

try:
upload_content_type = ContentType(upload.content_type)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="This endpoint can only process pdfs",
)

f = io.BytesIO(upload.file.read())
prompt = pdf_reader.as_text(f)
validate_extracted_text(prompt)
system_prompt = document_data_extractor.prompt

examples_text = ""
for input_text, output_text in await examples.evaluate():
validate_extracted_text(input_text)
examples_text += (
f"####\nINPUT: {input_text}\n\nOUTPUT: {output_text}\n\n"
)
full_system_content = f"{system_prompt}\n{examples_text}"

messages = [
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(
role="user",
content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}",
),
]
if (upload_content_type == ContentType.PDF and document_data_extractor.process_as == "text") or upload_content_type == ContentType.XLSX or upload_content_type == ContentType.XLS:
messages = await full_prompt_from_text(f, document_data_extractor, upload_content_type)
elif (upload_content_type == ContentType.PDF and document_data_extractor.process_as == "image") or upload_content_type == ContentType.PNG:
messages = await full_prompt_from_image(f, document_data_extractor, upload_content_type)
else:
raise NotImplementedError(f'Content type {upload_content_type} not supported')

pydantic_reponse = create_pydantic_model(
json.loads(document_data_extractor.response_template)
)
pydantic_reponse = create_pydantic_model(document_data_extractor.response_template)
format_response = {
"type": "json_schema",
"json_schema": {
@@ -490,10 +470,12 @@ async def extract_from_file(
top_logprobs=5,
response_format=format_response,
).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(
session, document_data_extractor.owner, chat_completion_request
).process_request()

identifier = chat_completion_response.id
extracted_data = chat_completion_response.choices[0].message.content
extracted_data_token = chat_completion_response.choices[0].logprobs.content
# TODO: handle refusal or case in which content was not correctly done
@@ -505,8 +487,7 @@ async def extract_from_file(
token_indices=map_characters_to_token_indices(extracted_data_token)
regex_spans=find_value_spans(extracted_data)
logprobs_sum=get_token_spans_and_logprobs(token_indices, regex_spans, extracted_data_token)
return {"extracted_data": json.loads(json_string), "extracted_logprobs":logprobs_sum}

return {"extracted_data": json.loads(json_string), "extracted_logprobs":logprobs_sum, "identifier": identifier}

def map_characters_to_token_indices(extracted_data_token: list[TokenLogprob]) -> list[int]:
"""
@@ -607,46 +588,5 @@ def get_token_spans_and_logprobs(
return logprobs_for_values


def create_pydantic_model(
schema: dict[
str,
tuple[
Literal["str", "int", "bool", "float"],
Literal["required", "optional"],
],
],
) -> Any:
"""Creates a pydantic model from an input dictionary where
keys are names of entities to be retrieved, each value is a tuple specifying
the type of the entity and whether it is required or optional"""
# Convert string type names to actual Python types
field_types = {
"str": (str, ...), # ... means the field is required
"int": (int, ...),
"float": (float, ...),
"bool": (bool, ...),
}
optional_field_types = {
"str": (str | None, ...), # ... means the field is required
"int": (int | None, ...),
"float": (float | None, ...),
"bool": (bool | None, ...),
}

# Dynamically create a Pydantic model using create_model
fields = {
name: field_types[ftype[0]]
if ftype[1] == "required"
else optional_field_types[ftype[0]]
for name, ftype in schema.items()
}
dynamic_model = create_model("DataExtractorSchema", **fields)
return dynamic_model
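
For reference, a quick usage sketch of the schema format this helper consumes (field name mapped to a (type, "required"/"optional") tuple), as described in the docstring above; the field names and values below are illustrative only, and the helper itself now lives in app.ops.schema_converter:

# Illustrative only — field names and values are made up.
schema = {
    "invoice_number": ("str", "required"),
    "total_amount": ("float", "required"),
    "discount": ("float", "optional"),
}
Model = create_pydantic_model(schema)
Model.model_validate(
    {"invoice_number": "INV-42", "total_amount": 99.9, "discount": None}
)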


def validate_extracted_text(text: str):
if text == "":
raise HTTPException(
status_code=500,
detail="The extracted text from the document is empty. Please check if the document is corrupted.",
)
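
The routing in extract_from_file above depends on a ContentType enum imported from app.models, which is outside this diff. A plausible sketch, assuming the enum values are the standard MIME types that ContentType(upload.content_type) would match against (the real definition may differ):

# Assumed sketch of app.models.ContentType — values are standard MIME types,
# but the actual enum may differ.
from enum import Enum

class ContentType(str, Enum):
    PDF = "application/pdf"
    PNG = "image/png"
    XLSX = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    XLS = "application/vnd.ms-excel"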
90 changes: 90 additions & 0 deletions backend/app/handlers/prompt_for_image.py
@@ -0,0 +1,90 @@
from app.lm.models import (
Message as ChatCompletionMessage,
)
import base64
from io import BytesIO
from typing import TypedDict, BinaryIO
from app.models import DocumentDataExtractor
from app.ops import tup
from app.ops.documents import as_png
from app.services.pdf_reader import pdf_reader
from app.services.png_reader import png_reader
from app.models import ContentType
from app.lm.models.chat_completion import ContentTypes, ImageUrlContent, TextContent

BASE64_IMAGE_PREFIX = "data:image/png;base64"


def fill_input_image(image_input: BytesIO | list[tuple[int, BytesIO]]) -> list[ContentTypes]:

def create_image_url_content(image: BytesIO) -> ImageUrlContent:
img_bytes = image.getvalue()
base64_image = (base64.b64encode(img_bytes).decode('utf-8'))
return ImageUrlContent(
type="image_url",
image_url={"url": f"{BASE64_IMAGE_PREFIX},{base64_image}"}
)

if isinstance(image_input, BytesIO):
return [create_image_url_content(image_input)]
elif isinstance(image_input, list):
content_list = []
for idx, img in image_input:
content_list.append(TextContent(type="text", text=f"####\nInput of page {idx}"))
content_list.append(create_image_url_content(img))
return content_list
else:
raise ValueError("Invalid input type for image_input")
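
A quick usage sketch of the two input shapes fill_input_image accepts (the byte payloads below are placeholders):

# Illustrative only — png_bytes and page*_bytes are placeholder variables.
single = fill_input_image(BytesIO(png_bytes))
# -> [ImageUrlContent] carrying a data:image/png;base64 URL
paged = fill_input_image([(1, BytesIO(page1_bytes)), (2, BytesIO(page2_bytes))])
# -> alternating TextContent page markers and ImageUrlContent entries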

# Link to the OpenAI documentation specifying that only the 'user' role is used for images:
# https://platform.openai.com/docs/guides/vision
async def full_prompt_from_image(file: BinaryIO, document_data_extractor: DocumentDataExtractor, upload_content_type: ContentType) -> list[ChatCompletionMessage]:
if upload_content_type == ContentType.PDF:
prompt = pdf_reader.as_pngs(file)
elif upload_content_type == ContentType.PNG:
prompt = png_reader.as_png(file)
else:
raise NotImplementedError(f'Content type {upload_content_type} not supported')

system_prompt = document_data_extractor.prompt

examples = tup(
*(
tup(
as_png(
document_data_extractor.owner,
example.document_id,
example.start_page,
example.end_page,
),
example.data,
)
for example in document_data_extractor.document_data_examples
)
)

messages = [
ChatCompletionMessage(
role="user",
content=[
TextContent(type="text", text=system_prompt)
]
)
]
for input_images, output_text in await examples.evaluate():
for idx, image in enumerate(input_images, start=1):
messages[0].content.append(
TextContent(type="text", text=f"####\nInput of page {idx}")
)
messages[0].content.extend(fill_input_image(image))

messages[0].content.append(
TextContent(type="text", text=f"Expected Output: {output_text}\n\n")
)

messages[0].content.append(
TextContent(type="text", text="Now, please apply the same analysis to the following new image:")
)
messages[0].content.extend(fill_input_image(prompt))

return messages