diff --git a/client/tests/astra-assistants/test_run_v2.py b/client/tests/astra-assistants/test_run_v2.py index 63aa2ef..b33b9c0 100644 --- a/client/tests/astra-assistants/test_run_v2.py +++ b/client/tests/astra-assistants/test_run_v2.py @@ -39,19 +39,38 @@ def run_with_assistant(assistant, client): +def create_and_run_with_assistant(assistant, client): + user_message = "What's your favorite animal." + + thread = client.beta.threads.create() + + client.beta.threads.messages.create( + thread_id=thread.id, role="user", content=user_message + ) + run = client.beta.threads.create_and_run( + thread=thread, + assistant_id=assistant.id, + ) + + logger.info(run) + + + + instructions="You're an animal expert who gives very long winded answers with flowery prose. Keep answers below 3 sentences." def test_run_gpt_4o_mini(patched_openai_client): gpt3_assistant = patched_openai_client.beta.assistants.create( name="GPT3 Animal Tutor", instructions=instructions, - model="gpt-4o_mini", + model="gpt-4o-mini", ) assistant = patched_openai_client.beta.assistants.retrieve(gpt3_assistant.id) logger.info(assistant) run_with_assistant(gpt3_assistant, patched_openai_client) + create_and_run_with_assistant(gpt3_assistant, patched_openai_client) def test_run_cohere(patched_openai_client): cohere_assistant = patched_openai_client.beta.assistants.create( @@ -91,4 +110,4 @@ def test_run_gemini(patched_openai_client): instructions=instructions, model="gemini/gemini-1.5-flash", ) - run_with_assistant(gemini_assistant, patched_openai_client) \ No newline at end of file + run_with_assistant(gemini_assistant, patched_openai_client) diff --git a/impl/routes_v2/threads_v2.py b/impl/routes_v2/threads_v2.py index 23210f2..7da69aa 100644 --- a/impl/routes_v2/threads_v2.py +++ b/impl/routes_v2/threads_v2.py @@ -112,6 +112,45 @@ async def create_thread( ) return astradb.upsert_table_from_base_model("threads", thread) +@router.post( + "/threads/runs", + responses={ + 200: {"model": RunObject, "description": "OK"}, + }, + tags=["Assistants"], + summary="Create a thread and run it in one request.", + response_model_by_alias=True, + response_model=None +) +async def create_thread_and_run( + create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""), + astradb: CassandraClient = Depends(verify_db_client), + embedding_model: str = Depends(infer_embedding_model), + embedding_api_key: str = Depends(infer_embedding_api_key), + litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs), +) -> RunObject: + create_thread_request = create_thread_and_run_request.thread + if create_thread_request is None: + raise HTTPException(status_code=400, detail="thread is required.") + + thread = await create_thread(create_thread_request, astradb) + + create_run_request = CreateRunRequest( + assistant_id=create_thread_and_run_request.assistant_id, + model=create_thread_and_run_request.model, + instructions=create_thread_and_run_request.instructions, + tools=create_thread_and_run_request.tools, + metadata=create_thread_and_run_request.metadata + ) + return await create_run( + thread_id=thread.id, + create_run_request=create_run_request, + astradb=astradb, + embedding_model=embedding_model, + embedding_api_key=embedding_api_key, + litellm_kwargs=litellm_kwargs, + ) + @router.get( "/threads/{thread_id}", responses={ @@ -1823,41 +1862,4 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id): return message_delta -@router.post( - "/threads/runs", - responses={ - 200: {"model": RunObject, "description": "OK"}, - }, - tags=["Assistants"], - summary="Create a thread and run it in one request.", - response_model_by_alias=True, - response_model=None -) -async def create_thread_and_run( - create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""), - astradb: CassandraClient = Depends(verify_db_client), - embedding_model: str = Depends(infer_embedding_model), - embedding_api_key: str = Depends(infer_embedding_api_key), - litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs), -) -> RunObject: - create_thread_request = create_thread_and_run_request.thread - if create_thread_request is None: - raise HTTPException(status_code=400, detail="thread is required.") - - thread = await create_thread(create_thread_request, astradb) - create_run_request = CreateRunRequest( - assistant_id=create_thread_and_run_request.assistant_id, - model=create_thread_and_run_request.model, - instructions=create_thread_and_run_request.instructions, - tools=create_thread_and_run_request.tools, - metadata=create_thread_and_run_request.metadata - ) - return await create_run( - thread_id=thread.id, - create_run_request=create_run_request, - astradb=astradb, - embedding_model=embedding_model, - embedding_api_key=embedding_api_key, - litellm_kwargs=litellm_kwargs, - )