From e7b723c0b5a514db63d8c652b0ccc9ca2908d24c Mon Sep 17 00:00:00 2001 From: JohntheLi Date: Wed, 29 Jan 2025 12:08:34 -0800 Subject: [PATCH] add parsed PDF content to expand_text_attachments (#143) If allow_attachments=True and expand_text_attachments=True, the contents of text, HTML, and PDF files will automatically be added to the conversation as inserted ProtocolMessages. If enable_image_comprehension=True, image descriptions will also be added to the conversation. --- pyproject.toml | 2 +- src/fastapi_poe/base.py | 16 ++++- tests/test_base.py | 125 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 tests/test_base.py diff --git a/pyproject.toml b/pyproject.toml index 3f3ad14..0f6daed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastapi_poe" -version = "0.0.54" +version = "0.0.55" authors = [ { name="Lida Li", email="lli@quora.com" }, { name="Jelle Zijlstra", email="jelle@quora.com" }, diff --git a/src/fastapi_poe/base.py b/src/fastapi_poe/base.py index f34a35c..d3e346c 100644 --- a/src/fastapi_poe/base.py +++ b/src/fastapi_poe/base.py @@ -583,7 +583,10 @@ def insert_attachment_messages(self, query_request: QueryRequest) -> QueryReques text_attachment_messages.append( ProtocolMessage(role="user", content=url_attachment_content) ) - elif "text" in attachment.content_type: + elif ( + attachment.content_type.startswith("text/") + or attachment.content_type == "application/pdf" + ): text_attachment_content = TEXT_ATTACHMENT_TEMPLATE.format( attachment_name=attachment.name, attachment_parsed_content=attachment.parsed_content, @@ -592,8 +595,15 @@ def insert_attachment_messages(self, query_request: QueryRequest) -> QueryReques ProtocolMessage(role="user", content=text_attachment_content) ) elif "image" in attachment.content_type: - parsed_content_filename = attachment.parsed_content.split("***")[0] - parsed_content_text = 
attachment.parsed_content.split("***")[1] + try: + # Poe currently sends analysis in the format of filename***analysis + parsed_content_filename, parsed_content_text = ( + attachment.parsed_content.split("***", 1) + ) + except ValueError: + # If the format is not filename***analysis, use the attachment filename + parsed_content_filename = attachment.name + parsed_content_text = attachment.parsed_content image_attachment_content = IMAGE_VISION_ATTACHMENT_TEMPLATE.format( filename=parsed_content_filename, parsed_image_description=parsed_content_text, diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..57c3d1a --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,125 @@ +from fastapi_poe.base import PoeBot +from fastapi_poe.templates import ( + IMAGE_VISION_ATTACHMENT_TEMPLATE, + TEXT_ATTACHMENT_TEMPLATE, + URL_ATTACHMENT_TEMPLATE, +) +from fastapi_poe.types import Attachment, ProtocolMessage, QueryRequest + + +class TestPoeBot: + def test_insert_attachment_messages(self) -> None: + # Create mock attachments + mock_text_attachment = Attachment( + url="https://pfst.cf2.poecdn.net/base/text/test.txt", + name="test.txt", + content_type="text/plain", + parsed_content="Hello, world!", + ) + mock_image_attachment = Attachment( + url="https://pfst.cf2.poecdn.net/base/image/test.png", + name="test.png", + content_type="image/png", + parsed_content="test.png***Hello, world!", + ) + mock_image_attachment_2 = Attachment( + url="https://pfst.cf2.poecdn.net/base/image/test.png", + name="testimage2.jpg", + content_type="image/jpeg", + parsed_content="Hello, world!", + ) + mock_pdf_attachment = Attachment( + url="https://pfst.cf2.poecdn.net/base/application/test.pdf", + name="test.pdf", + content_type="application/pdf", + parsed_content="Hello, world!", + ) + mock_html_attachment = Attachment( + url="https://pfst.cf2.poecdn.net/base/text/test.html", + name="test.html", + content_type="text/html", + parsed_content="Hello, world!", + ) + 
mock_video_attachment = Attachment( + url="https://pfst.cf2.poecdn.net/base/video/test.mp4", + name="test.mp4", + content_type="video/mp4", + parsed_content="Hello, world!", + ) + # Create mock protocol messages + message_without_attachments = ProtocolMessage( + role="user", content="Hello, world!" + ) + message_with_attachments = ProtocolMessage( + role="user", + content="Here's some attachments", + attachments=[ + mock_text_attachment, + mock_image_attachment, + mock_image_attachment_2, + mock_pdf_attachment, + mock_html_attachment, + mock_video_attachment, + ], + ) + # Create mock query request + mock_query_request = QueryRequest( + version="1.0", + type="query", + query=[message_without_attachments, message_with_attachments], + user_id="123", + conversation_id="123", + message_id="456", + ) + + assert ( + mock_image_attachment.parsed_content + ) # satisfy pyright so split() works below + expected_protocol_messages = [ + message_without_attachments, + ProtocolMessage( + role="user", + content=TEXT_ATTACHMENT_TEMPLATE.format( + attachment_name=mock_text_attachment.name, + attachment_parsed_content=mock_text_attachment.parsed_content, + ), + ), + ProtocolMessage( + role="user", + content=TEXT_ATTACHMENT_TEMPLATE.format( + attachment_name=mock_pdf_attachment.name, + attachment_parsed_content=mock_pdf_attachment.parsed_content, + ), + ), + ProtocolMessage( + role="user", + content=URL_ATTACHMENT_TEMPLATE.format( + attachment_name=mock_html_attachment.name, + content=mock_html_attachment.parsed_content, + ), + ), + ProtocolMessage( + role="user", + content=IMAGE_VISION_ATTACHMENT_TEMPLATE.format( + filename=mock_image_attachment.parsed_content.split("***")[0], + parsed_image_description=mock_image_attachment.parsed_content.split( + "***" + )[1], + ), + ), + ProtocolMessage( + role="user", + content=IMAGE_VISION_ATTACHMENT_TEMPLATE.format( + filename=mock_image_attachment_2.name, + parsed_image_description=mock_image_attachment_2.parsed_content, + ), + ), + 
message_with_attachments, + ] + + # Test the insert_attachment_messages method + bot = PoeBot(bot_name="test_bot") + modified_query_request = bot.insert_attachment_messages(mock_query_request) + protocol_messages = modified_query_request.query + + assert protocol_messages == expected_protocol_messages