Keep a reference to asyncio tasks in astream_chat() #17812

Merged
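Background on the change: `asyncio.create_task()` returns a `Task` that the event loop tracks only through a weak reference, so a fire-and-forget call can be garbage-collected before it finishes, and an exception raised inside it is only reported later as "Task exception was never retrieved". The diffs below keep a strong reference by assigning the task to `StreamingAgentChatResponse.awrite_response_to_history_task` and awaiting it in a `finally` block once the stream is consumed. A minimal sketch of the general pattern follows; `StreamingResponse`, `start_background_write`, and `finish` are illustrative names, not llama_index APIs:

```python
import asyncio
from typing import Optional


class StreamingResponse:
    """Toy stand-in for a streaming response that writes itself to memory in the background."""

    def __init__(self) -> None:
        # Strong reference to the background task; without it, the event loop only
        # holds a weak reference and the task can be garbage-collected mid-flight.
        self.write_task: Optional[asyncio.Task] = None

    async def _write_to_history(self) -> None:
        await asyncio.sleep(0)  # stand-in for writing the streamed response to chat memory

    def start_background_write(self) -> None:
        # Keep the handle returned by create_task() instead of discarding it.
        self.write_task = asyncio.create_task(self._write_to_history())

    async def finish(self) -> None:
        # Await the task so it runs to completion and any exception it raised is
        # retrieved here rather than surfacing later during garbage collection.
        if self.write_task is not None:
            try:
                await self.write_task
            finally:
                self.write_task = None  # drop the reference once the task is done


async def main() -> None:
    response = StreamingResponse()
    response.start_background_write()
    await response.finish()


asyncio.run(main())
```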
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/agent/legacy/react/base.py

@@ -518,10 +518,10 @@ async def astream_chat(
             achat_stream=response_stream, sources=self.sources
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self._memory)
         )
-        # thread.start()

         return chat_stream_response

     def get_tools(self, message: str) -> List[AsyncBaseTool]:
2 changes: 1 addition & 1 deletion llama-index-core/llama_index/core/agent/react/step.py

@@ -797,7 +797,7 @@ async def _arun_step_stream(
                 sources=task.extra_state["sources"],
             )
             # create task to write chat response to history
-            asyncio.create_task(
+            agent_response_stream.awrite_response_to_history_task = asyncio.create_task(
                 agent_response_stream.awrite_response_to_history(
                     task.extra_state["new_memory"],
                     on_stream_end_fn=partial(self.finalize_task, task),
@@ -359,7 +359,10 @@ async def astream_chat(
                 ),
                 sources=[tool_output],
             )
-            asyncio.create_task(response.awrite_response_to_history(self._memory))
+            response.awrite_response_to_history_task = asyncio.create_task(
+                response.awrite_response_to_history(self._memory)
+            )

         else:
             raise ValueError("Streaming is not enabled. Please use achat() instead.")
         return response
4 changes: 3 additions & 1 deletion llama-index-core/llama_index/core/chat_engine/simple.py

@@ -202,7 +202,9 @@ async def astream_chat(
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages)
         )
-        asyncio.create_task(chat_response.awrite_response_to_history(self._memory))
+        chat_response.awrite_response_to_history_task = asyncio.create_task(
+            chat_response.awrite_response_to_history(self._memory)
+        )

         return chat_response
65 changes: 38 additions & 27 deletions llama-index-core/llama_index/core/chat_engine/types.py

@@ -124,6 +124,7 @@ class StreamingAgentChatResponse:
     is_writing_to_memory: bool = True
     # Track if an exception occurred
     exception: Optional[Exception] = None
+    awrite_response_to_history_task: Optional[asyncio.Task] = None

     def set_source_nodes(self) -> None:
         if self.sources and not self.source_nodes:

@@ -300,34 +301,44 @@ def response_gen(self) -> Generator[str, None, None]:
         self.response = self.unformatted_response.strip()

     async def async_response_gen(self) -> AsyncGenerator[str, None]:
-        self._ensure_async_setup()
-        assert self.aqueue is not None
-
-        if self.is_writing_to_memory:
-            while True:
-                if not self.aqueue.empty() or not self.is_done:
-                    if self.exception is not None:
-                        raise self.exception
-
-                    try:
-                        delta = await asyncio.wait_for(self.aqueue.get(), timeout=0.1)
-                    except asyncio.TimeoutError:
-                        if self.is_done:
-                            break
-                        continue
-                    if delta is not None:
-                        self.unformatted_response += delta
-                        yield delta
-                else:
-                    break
-        else:
-            if self.achat_stream is None:
-                raise ValueError("achat_stream is None!")
+        try:
+            self._ensure_async_setup()
+            assert self.aqueue is not None
+
+            if self.is_writing_to_memory:
+                while True:
+                    if not self.aqueue.empty() or not self.is_done:
+                        if self.exception is not None:
+                            raise self.exception
+
+                        try:
+                            delta = await asyncio.wait_for(
+                                self.aqueue.get(), timeout=0.1
+                            )
+                        except asyncio.TimeoutError:
+                            if self.is_done:
+                                break
+                            continue
+                        if delta is not None:
+                            self.unformatted_response += delta
+                            yield delta
+                    else:
+                        break
+            else:
+                if self.achat_stream is None:
+                    raise ValueError("achat_stream is None!")

-            async for chat_response in self.achat_stream:
-                self.unformatted_response += chat_response.delta or ""
-                yield chat_response.delta or ""
-        self.response = self.unformatted_response.strip()
+                async for chat_response in self.achat_stream:
+                    self.unformatted_response += chat_response.delta or ""
+                    yield chat_response.delta or ""
+            self.response = self.unformatted_response.strip()
+        finally:
+            if self.awrite_response_to_history_task:
+                # Make sure that the background task ran to completion, retrieve any exceptions
+                await self.awrite_response_to_history_task
+                self.awrite_response_to_history_task = (
+                    None  # No need to keep the reference to the finished task
+                )

     def print_response_stream(self) -> None:
         for token in self.response_gen:
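With the task held on the response object, fully consuming `async_response_gen()` also awaits the background history write in its `finally` block, so an exception raised by `awrite_response_to_history()` propagates to the caller instead of dying with an unreferenced task. A hypothetical usage sketch, assuming a default LLM resolves in your environment (the new test below exercises the same path with `MockLLM`):

```python
import asyncio

from llama_index.core.chat_engine.simple import SimpleChatEngine


async def main() -> None:
    engine = SimpleChatEngine.from_defaults()
    response = await engine.astream_chat("Hello World!")

    # Exhausting the generator drives the stream; on exit, the finally block above
    # awaits awrite_response_to_history_task, so the history write has completed
    # and any exception from it is re-raised here.
    async for token in response.async_response_gen():
        print(token, end="")

    print()
    print(len(engine.chat_history))  # user message plus assistant reply


asyncio.run(main())
```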
95 changes: 95 additions & 0 deletions llama-index-core/tests/chat_engine/test_simple.py

@@ -1,3 +1,15 @@
+import gc
+import asyncio
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    CompletionResponse,
+    CompletionResponseGen,
+)
+from typing import Any
+from llama_index.core.llms.callbacks import llm_completion_callback
+from llama_index.core.llms.mock import MockLLM
 import pytest
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.chat_engine.simple import SimpleChatEngine
@@ -34,3 +46,86 @@ def test_simple_chat_engine_with_init_history() -> None:
         str(response) == "user: test human message\nassistant: test ai message\n"
         "user: new human message\nassistant: "
     )
+
+
+@pytest.mark.asyncio()
+async def test_simple_chat_engine_astream():
+    engine = SimpleChatEngine.from_defaults()
+    response = await engine.astream_chat("Hello World!")
+
+    num_iters = 0
+    async for response_part in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+
+    assert "Hello World!" in response.unformatted_response
+    assert len(engine.chat_history) == 2
+
+    response = await engine.astream_chat("What is the capital of the moon?")
+
+    num_iters = 0
+    async for _ in response.async_response_gen():
+        num_iters += 1
+
+    assert num_iters > 10
+    assert "Hello World!" in response.unformatted_response
+    assert "What is the capital of the moon?" in response.unformatted_response
+
+
+def test_simple_chat_engine_astream_exception_handling():
+    """Test that an exception thrown while retrieving the streamed LLM response gets bubbled up to the user.
+
+    Also tests that the unretrieved exception does not remain in a task that was never awaited, which would
+    lead to a 'Task exception was never retrieved' message during garbage collection.
+    """
+
+    class ExceptionThrownInTest(Exception):
+        pass
+
+    class ExceptionMockLLM(MockLLM):
+        """Raises an exception while streaming back the mocked LLM response."""
+
+        @classmethod
+        def class_name(cls) -> str:
+            return "ExceptionMockLLM"
+
+        @llm_completion_callback()
+        def stream_complete(
+            self, prompt: str, formatted: bool = False, **kwargs: Any
+        ) -> CompletionResponseGen:
+            def gen_prompt() -> CompletionResponseGen:
+                for ch in prompt:
+                    yield CompletionResponse(
+                        text=prompt,
+                        delta=ch,
+                    )
+                raise ExceptionThrownInTest("Exception thrown for testing purposes")
+
+            return gen_prompt()
+
+    async def async_part():
+        engine = SimpleChatEngine.from_defaults(
+            llm=ExceptionMockLLM(), memory=ChatMemoryBuffer.from_defaults()
+        )
+        response = await engine.astream_chat("Hello World!")
+
+        with pytest.raises(ExceptionThrownInTest):
+            async for response_part in response.async_response_gen():
+                pass
+
+    not_retrieved_exception = False
+
+    def custom_exception_handler(loop, context):
+        if context.get("message") == "Task exception was never retrieved":
+            nonlocal not_retrieved_exception
+            not_retrieved_exception = True
+
+    loop = asyncio.new_event_loop()
+    loop.set_exception_handler(custom_exception_handler)
+    result = loop.run_until_complete(async_part())
+    loop.close()
+    gc.collect()
+    if not_retrieved_exception:
+        pytest.fail(
+            "Exception was not correctly handled - ended up in asyncio cleanup performed during garbage collection"
+        )
@@ -244,7 +244,7 @@ async def _get_async_stream_ai_response(
             sources=self.sources,
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(self.memory)
         )
         # wait until openAI functions stop executing
@@ -277,7 +277,7 @@ async def _get_async_stream_ai_response(
             sources=task.extra_state["sources"],
         )
         # create task to write chat response to history
-        asyncio.create_task(
+        chat_stream_response.awrite_response_to_history_task = asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
                 task.extra_state["new_memory"],
                 on_stream_end_fn=partial(self.afinalize_task, task),