diff --git a/README.md b/README.md index eefb891..abdc539 100644 --- a/README.md +++ b/README.md @@ -138,38 +138,6 @@ taskingai.retrieval.delete_collection(collection_id=coll.collection_id) print("Collection deleted.") ``` -### Tools - -The Tools module in TaskingAI is an essential suite designed to augment the capabilities of TaskingAI agents. Here is an example of how to create, run, and delete a tool action: - -```python -import taskingai - -# Define a schema for the tool action -OPENAPI_SCHEMA = { - # Schema definition goes here -} - -# Create a tool action based on the defined schema -actions = taskingai.tool.bulk_create_actions( - openapi_schema=OPENAPI_SCHEMA, - authentication={"type": "none"}, -) -action = actions[0] -print(f"Action created: {action.action_id}") - -# Run the action for a test purpose -result = taskingai.tool.run_action( - action_id=action.action_id, - parameters={"number": 42} -) -print(f"Action result: {result}") - -# Delete the action when done -taskingai.tool.delete_action(action_id=action.action_id) -print("Action deleted.") -``` - ## Contributing We welcome contributions of all kinds. Please read our [Contributing Guidelines](./CONTRIBUTING.md) for more information on how to get started. diff --git a/examples/assistant/chat_with_assistant.ipynb b/examples/assistant/chat_with_assistant.ipynb index 8f0f679..ddb9514 100644 --- a/examples/assistant/chat_with_assistant.ipynb +++ b/examples/assistant/chat_with_assistant.ipynb @@ -2,110 +2,49 @@ "cells": [ { "cell_type": "code", + "execution_count": null, "id": "initial_id", "metadata": { "collapsed": true }, + "outputs": [], "source": [ "import time\n", "import taskingai\n", "# Load TaskingAI API Key from environment variable" - ], - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", + "id": "4ca20b4a868dedd8", + "metadata": { + "collapsed": false + }, "source": [ "# TaskingAI: Chat with Assistant Example\n", "\n", "In this example, we will first create an assistant who knows the meaning of various numbers and will explain it in certain language.\n", "Then we will start a chat with the assistant." 
- ], - "metadata": { - "collapsed": false - }, - "id": "4ca20b4a868dedd8" + ] }, { "cell_type": "markdown", - "source": [ - "## Create Assistant" - ], + "id": "5e19ac923d84e898", "metadata": { "collapsed": false }, - "id": "5e19ac923d84e898" + "source": [ + "## Create Assistant" + ] }, { "cell_type": "code", - "source": [ - "from taskingai.tool import Action, ActionAuthentication, ActionAuthenticationType\n", - "from typing import List\n", - "\n", - "# create an assistant action\n", - "NUMBERS_API_SCHEMA = {\n", - " \"openapi\": \"3.0.0\",\n", - " \"info\": {\n", - " \"title\": \"Numbers API\",\n", - " \"version\": \"1.0.0\",\n", - " \"description\": \"API for fetching interesting number facts\"\n", - " },\n", - " \"servers\": [\n", - " {\n", - " \"url\": \"http://numbersapi.com\"\n", - " }\n", - " ],\n", - " \"paths\": {\n", - " \"/{number}\": {\n", - " \"get\": {\n", - " \"description\": \"Get a fact about a number\",\n", - " \"operationId\": \"getNumberFact\",\n", - " \"parameters\": [\n", - " {\n", - " \"name\": \"number\",\n", - " \"in\": \"path\",\n", - " \"required\": True,\n", - " \"description\": \"The number to get the fact for\",\n", - " \"schema\": {\n", - " \"type\": \"integer\"\n", - " }\n", - " }\n", - " ],\n", - " \"responses\": {\n", - " \"200\": {\n", - " \"description\": \"A fact about the number\",\n", - " \"content\": {\n", - " \"text/plain\": {\n", - " \"schema\": {\n", - " \"type\": \"string\"\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - " }\n", - "}\n", - "actions: List[Action] = taskingai.tool.bulk_create_actions(\n", - " openapi_schema=NUMBERS_API_SCHEMA,\n", - " authentication=ActionAuthentication(\n", - " type=ActionAuthenticationType.NONE,\n", - " )\n", - ")\n", - "action = actions[0]\n", - "print(f\"created action: {action}\\n\")" - ], + "execution_count": null, + "id": "3b3df0f232021283", "metadata": { "collapsed": false }, - "id": "3b2fda39ba58c5e9", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "from taskingai.assistant import Assistant, Chat, ToolRef, ToolType\n", "from taskingai.assistant.memory import AssistantMessageWindowMemory\n", @@ -118,7 +57,6 @@ " name=\"My Assistant\",\n", " description=\"A assistant who knows the meaning of various numbers.\",\n", " memory=AssistantMessageWindowMemory(\n", - " max_messages=20,\n", " max_tokens=1000\n", " ),\n", " system_prompt_template=[\n", @@ -127,49 +65,49 @@ " ],\n", " tools=[\n", " ToolRef(\n", - " type=ToolType.ACTION,\n", - " id=action.action_id,\n", + " type=ToolType.PLUGIN,\n", + " id=\"open_weather/get_hourly_forecast\",\n", " )\n", " ],\n", " retrievals=[],\n", " metadata={\"k\": \"v\"},\n", ")\n", "print(f\"created assistant: {assistant}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "3b3df0f232021283", - "outputs": [], - "execution_count": null + ] }, { "cell_type": "markdown", - "source": [ - "## Start a Chat " - ], + "id": "8e7c1b9461f0a344", "metadata": { "collapsed": false }, - "id": "8e7c1b9461f0a344" + "source": [ + "## Start a Chat " + ] }, { "cell_type": "code", + "execution_count": null, + "id": "f1e2f0b2af8b1d8d", + "metadata": { + "collapsed": false + }, + "outputs": [], "source": [ "chat: Chat = taskingai.assistant.create_chat(\n", " assistant_id=assistant.assistant_id,\n", ")\n", "print(f\"created chat: {chat.chat_id}\\n\")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b26e30b79b71697a", "metadata": { "collapsed": false }, - "id": "f1e2f0b2af8b1d8d", "outputs": [], - 
"execution_count": null - }, - { - "cell_type": "code", "source": [ "from taskingai.assistant import Message, MessageChunk\n", "user_input = input(\"User Input: \")\n", @@ -181,7 +119,7 @@ " text=user_input,\n", " )\n", " print(f\"User: {user_input}\")\n", - " \n", + "\n", " # generate assistant response\n", " assistant_message: Message = taskingai.assistant.generate_message(\n", " assistant_id=assistant.assistant_id,\n", @@ -194,16 +132,16 @@ " time.sleep(2)\n", " # quit by input 'q\n", " user_input = input(\"User: \")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7d73e0b138e3eba", "metadata": { "collapsed": false }, - "id": "b26e30b79b71697a", "outputs": [], - "execution_count": null - }, - { - "cell_type": "code", "source": [ "user_input = input(\"User Input: \")\n", "while user_input.strip() and user_input != \"q\":\n", @@ -214,7 +152,7 @@ " text=user_input,\n", " )\n", " print(f\"User: {user_input} ({user_message.message_id})\")\n", - " \n", + "\n", " # generate assistant response\n", " assistant_message_response = taskingai.assistant.generate_message(\n", " assistant_id=assistant.assistant_id,\n", @@ -224,27 +162,37 @@ " },\n", " stream=True,\n", " )\n", - " \n", - " print(f\"Assistant:\", end=\" \", flush=True)\n", + "\n", + " print(\"Assistant:\", end=\" \", flush=True)\n", " for item in assistant_message_response:\n", " if isinstance(item, MessageChunk):\n", " print(item.delta, end=\"\", flush=True)\n", " elif isinstance(item, Message):\n", " print(f\" ({item.message_id})\")\n", - " \n", + "\n", " time.sleep(2)\n", " # quit by input 'q\n", " user_input = input(\"User: \")" - ], - "metadata": { - "collapsed": false - }, - "id": "c7d73e0b138e3eba", + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a67261c", + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# clean chat context\n", + "taskingai.assistant.clean_chat_context(\n", + " assistant_id=assistant.assistant_id,\n", + " chat_id=chat.chat_id,\n", + ")" + ] }, { "cell_type": "code", + "execution_count": null, + "outputs": [], "source": [ "# list messages\n", "messages = taskingai.assistant.list_messages(\n", @@ -258,12 +206,12 @@ "metadata": { "collapsed": false }, - "id": "e94e3adb0d15373b", - "outputs": [], - "execution_count": null + "id": "e94e3adb0d15373b" }, { "cell_type": "code", + "execution_count": null, + "outputs": [], "source": [ "# delete assistant\n", "taskingai.assistant.delete_assistant(\n", @@ -273,9 +221,7 @@ "metadata": { "collapsed": false }, - "id": "ed39836bbfdc7a4e", - "outputs": [], - "execution_count": null + "id": "ed39836bbfdc7a4e" } ], "metadata": { @@ -287,14 +233,14 @@ "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/crud/assistant_crud.ipynb b/examples/crud/assistant_crud.ipynb index 2e154a2..b7555da 100644 --- a/examples/crud/assistant_crud.ipynb +++ b/examples/crud/assistant_crud.ipynb @@ -51,7 +51,7 @@ " name=\"Customer Service Assistant\",\n", " description=\"A professional assistant for customer service.\",\n", " system_prompt_template=[\"You are a professional customer service assistant speaking {{language}}.\"],\n", - " memory={\"type\": \"naive\",},\n", + " memory={\"type\": \"message_window\",},\n", " tools=[],\n", " 
retrievals=[],\n", " retrieval_configs={\n", @@ -236,16 +236,34 @@ "# delete assistant\n", "taskingai.assistant.delete_assistant(assistant.assistant_id)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" } }, "nbformat": 4, diff --git a/examples/crud/retrieval_crud.ipynb b/examples/crud/retrieval_crud.ipynb index 5d8fae9..58b0e22 100644 --- a/examples/crud/retrieval_crud.ipynb +++ b/examples/crud/retrieval_crud.ipynb @@ -71,6 +71,7 @@ "# create a collection\n", "def create_collection():\n", " collection = taskingai.retrieval.create_collection(\n", + " type=\"text\",\n", " embedding_model_id=embedding_model_id,\n", " capacity=1000 # maximum text chunks can be stored \n", " )\n", @@ -289,6 +290,10 @@ { "cell_type": "code", "execution_count": null, + "id": "832ae91419da5493", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "# create a new file record\n", @@ -300,24 +305,20 @@ " text_splitter={\"type\": \"token\", \"chunk_size\": 200, \"chunk_overlap\": 20},\n", ")\n", "print(f\"created record: {record.record_id} for collection: {collection.collection_id}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "832ae91419da5493" + ] }, { "cell_type": "code", "execution_count": null, + "id": "8176058e6c15a1e0", + "metadata": { + "collapsed": false + }, "outputs": [], "source": [ "new_file = upload_file(file=open(\"../../test/files/test.docx\", \"rb\"), purpose=\"record_file\")\n", "print(f\"new uploaded file id: {new_file.file_id}\")" - ], - "metadata": { - "collapsed": false - }, - "id": "8176058e6c15a1e0" + ] }, { "cell_type": "code", @@ -422,93 +423,6 @@ "## Chunk Object" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "a395337f136500fc", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# create a new text record\n", - "chunk = taskingai.retrieval.create_chunk(\n", - " collection_id=collection.collection_id,\n", - " content=\"The dog is a domesticated descendant of the wolf. Also called the domestic dog, it is derived from extinct gray wolves, and the gray wolf is the dog's closest living relative. 
The dog was the first species to be domesticated by humans.\",\n", - ")\n", - "print(f\"created chunk: {chunk.chunk_id} for collection: {collection.collection_id}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "309e1771251bb079", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# update chunk metadata\n", - "chunk = taskingai.retrieval.update_chunk(\n", - " collection_id=collection.collection_id,\n", - " chunk_id=chunk.chunk_id,\n", - " metadata={\"k\": \"v\"},\n", - ")\n", - "print(f\"updated chunk: {chunk}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9d68db12329b558", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# update chunk content\n", - "chunk = taskingai.retrieval.update_chunk(\n", - " collection_id=collection.collection_id,\n", - " chunk_id=chunk.chunk_id,\n", - " content=\"New content\",\n", - ")\n", - "print(f\"updated chunk: {chunk}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d3899097cd6d0cf2", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# get chunk\n", - "chunk = taskingai.retrieval.get_chunk(\n", - " collection_id=collection.collection_id,\n", - " chunk_id=chunk.chunk_id\n", - ")\n", - "print(f\"got chunk: {chunk}\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27e643ad8e8636ed", - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "# delete chunk\n", - "taskingai.retrieval.delete_chunk(\n", - " collection_id=collection.collection_id,\n", - " chunk_id=chunk.chunk_id,\n", - ")\n", - "print(f\"deleted chunk {chunk.chunk_id} from collection {collection.collection_id}\\n\")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -524,28 +438,21 @@ " type=\"text\",\n", " content=\"Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data. The term \\\"machine learning\\\" was coined by Arthur Samuel in 1959. In other words, machine learning enables a system to automatically learn and improve from experience without being explicitly programmed. This is achieved by feeding the system massive amounts of data, which it uses to learn patterns and make inferences. There are three main types of machine learning: 1. Supervised Learning: This is where the model is given labeled training data and the goal of learning is to generalize from the training data to unseen situations in a principled way. 2. Unsupervised Learning: This involves training on a dataset without explicit labels. The goal might be to discover inherent groupings or patterns within the data. 3. Reinforcement Learning: In this type, an agent learns to perform actions based on reward/penalty feedback to achieve a goal. It's commonly used in robotics, gaming, and navigation. Deep learning, a subset of machine learning, uses neural networks with many layers (\\\"deep\\\" structures) and has been responsible for many recent breakthroughs in AI, including speech recognition, image recognition, and natural language processing. 
It's important to note that machine learning is a rapidly developing field, with new techniques and applications emerging regularly.\",\n", " text_splitter={\"type\": \"token\", \"chunk_size\": 400, \"chunk_overlap\": 20},\n", - ")\n", - "\n", - "taskingai.retrieval.create_chunk(\n", - " collection_id=collection.collection_id,\n", - " content=\"The dog is a domesticated descendant of the wolf. Also called the domestic dog, it is derived from extinct gray wolves, and the gray wolf is the dog's closest living relative. The dog was the first species to be domesticated by humans.\",\n", - ")" + ")\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "55e9645ac41f8ca", - "metadata": { - "collapsed": false - }, + "id": "3d18c7fd", + "metadata": {}, "outputs": [], "source": [ - "# list chunks\n", - "chunks = taskingai.retrieval.list_chunks(collection_id=collection.collection_id)\n", - "for chunk in chunks:\n", - " print(chunk)\n", - " print(\"-\" * 50)" + "# query chunk\n", + "taskingai.retrieval.query_chunks(\n", + " collection_id=collection.collection_id,\n", + " query_text=\"text to query\"\n", + ")" ] }, { @@ -560,6 +467,16 @@ "collapsed": false }, "id": "b97aaa156f586e34" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "d7aef75e8c36fc00" } ], "metadata": { diff --git a/examples/crud/tool_crud.ipynb b/examples/crud/tool_crud.ipynb deleted file mode 100644 index 51b1337..0000000 --- a/examples/crud/tool_crud.ipynb +++ /dev/null @@ -1,233 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "initial_id", - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "import taskingai\n", - "# Load TaskingAI API Key from environment variable" - ] - }, - { - "cell_type": "markdown", - "source": [ - "# TaskingAI Tool Module CRUD Exampple" - ], - "metadata": { - "collapsed": false - }, - "id": "43e38e6ab25cd370" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "from taskingai.tool import Action" - ], - "metadata": { - "collapsed": false - }, - "id": "1da88cd4d728ced9" - }, - { - "cell_type": "markdown", - "source": [ - "## Action Object" - ], - "metadata": { - "collapsed": false - }, - "id": "5ab9d33db7c623a5" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "from taskingai.tool import ActionAuthentication, ActionAuthenticationType\n", - "from typing import List\n", - "\n", - "# create an Action\n", - "NUMBERS_API_SCHEMA = {\n", - " \"openapi\": \"3.0.0\",\n", - " \"info\": {\n", - " \"title\": \"Numbers API\",\n", - " \"version\": \"1.0.0\",\n", - " \"description\": \"API for fetching interesting number facts\"\n", - " },\n", - " \"servers\": [\n", - " {\n", - " \"url\": \"http://numbersapi.com\"\n", - " }\n", - " ],\n", - " \"paths\": {\n", - " \"/{number}\": {\n", - " \"get\": {\n", - " \"description\": \"Get fact about a number\",\n", - " \"operationId\": \"getNumberFact\",\n", - " \"parameters\": [\n", - " {\n", - " \"name\": \"number\",\n", - " \"in\": \"path\",\n", - " \"required\": True,\n", - " \"description\": \"The number to get the fact for\",\n", - " \"schema\": {\n", - " \"type\": \"integer\"\n", - " }\n", - " }\n", - " ],\n", - " \"responses\": {\n", - " \"200\": {\n", - " \"description\": \"A fact about the number\",\n", - " \"content\": {\n", - " \"text/plain\": {\n", - " \"schema\": {\n", - " \"type\": \"string\"\n", - " }\n", - " }\n", - " }\n", - " 
}\n", - " }\n", - " }\n", - " }\n", - " }\n", - "}\n", - "actions: List[Action] = taskingai.tool.bulk_create_actions(\n", - " openapi_schema=NUMBERS_API_SCHEMA,\n", - " authentication=ActionAuthentication(\n", - " type=ActionAuthenticationType.NONE,\n", - " )\n", - ")\n", - "action = actions[0]\n", - "print(f\"created action: {action}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "1b40bb3464107aa6" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# get action\n", - "action_id: str = action.action_id\n", - "action: Action = taskingai.tool.get_action(\n", - " action_id=action_id\n", - ")\n", - "\n", - "print(f\"got action: {action}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "e991bb600ca2bca9" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# update action\n", - "NUMBERS_API_SCHEMA[\"paths\"][\"/{number}\"][\"get\"][\"summary\"] = \"Get fun fact about a number)\"\n", - "action: Action = taskingai.tool.update_action(\n", - " action_id=action_id,\n", - " openapi_schema=NUMBERS_API_SCHEMA\n", - ")\n", - "\n", - "print(f\"updated action: {action}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "495db38e51c0531f" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# run action for test purpose\n", - "result = taskingai.tool.run_action(\n", - " action_id=action_id,\n", - " parameters={\n", - " \"number\": 127\n", - " }\n", - ")\n", - "print(f\"ran action result: {result}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "36d4c722d4b9bd74" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# delete action\n", - "taskingai.tool.delete_action(action_id=action_id)\n", - "print(f\"deleted action: {action_id}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "dd53cb15efa35298" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# list actions\n", - "actions = taskingai.tool.list_actions()\n", - "action_ids = [action.action_id for action in actions]\n", - "# ensure the action we deleted is not in the list\n", - "print(f\"f{action_id} in action_ids: {action_id in action_ids}\\n\")" - ], - "metadata": { - "collapsed": false - }, - "id": "5a1a36d15055918f" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "5588d5e7457225be" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/inference/chat_completion.ipynb b/examples/inference/chat_completion.ipynb index 4253aa5..56cecc6 100644 --- a/examples/inference/chat_completion.ipynb +++ b/examples/inference/chat_completion.ipynb @@ -31,7 +31,7 @@ "from taskingai.inference import *\n", "import json\n", "# choose an available chat_completion model from your project\n", - "model_id = \"YOUR_MODEL_ID\"" + "model_id = \"YOUR_CHAT_COMPLETION_MODEL_ID\"" ], "metadata": { "collapsed": false @@ -265,6 +265,16 @@ "collapsed": false }, "id": "4f3290f87635152a" + }, + { + "cell_type": "code", + "execution_count": null, + 
"outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "f9a7066d18575dc9" } ], "metadata": { diff --git a/taskingai/__init__.py b/taskingai/__init__.py index 0113196..a08f2b2 100644 --- a/taskingai/__init__.py +++ b/taskingai/__init__.py @@ -1,6 +1,5 @@ from .config import * from . import assistant -from . import tool from . import retrieval from . import inference from . import file diff --git a/taskingai/_version.py b/taskingai/_version.py index 635fd03..55cd3d9 100644 --- a/taskingai/_version.py +++ b/taskingai/_version.py @@ -1,2 +1,2 @@ __title__ = "taskingai" -__version__ = "0.2.5" +__version__ = "0.3.0" diff --git a/taskingai/assistant/chat.py b/taskingai/assistant/chat.py index 578f5ec..9131226 100644 --- a/taskingai/assistant/chat.py +++ b/taskingai/assistant/chat.py @@ -15,6 +15,8 @@ "a_create_chat", "a_update_chat", "a_delete_chat", + "clean_chat_context", + "a_clean_chat_context", ] @@ -238,3 +240,27 @@ async def a_delete_chat( """ await async_api_delete_chat(assistant_id=assistant_id, chat_id=chat_id) + + +def clean_chat_context(assistant_id: str, chat_id: str) -> Message: + """ + Clean chat context. + + :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. + """ + + response = api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + return response.data + + +async def a_clean_chat_context(assistant_id: str, chat_id: str) -> Message: + """ + Clean chat context in async mode. + + :param assistant_id: The ID of the assistant. + :param chat_id: The ID of the chat. + """ + + response = await async_api_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id) + return response.data diff --git a/taskingai/assistant/memory.py b/taskingai/assistant/memory.py index 3f94eb6..074c5b1 100644 --- a/taskingai/assistant/memory.py +++ b/taskingai/assistant/memory.py @@ -6,22 +6,10 @@ __all__ = [ "AssistantMemory", "AssistantMemoryType", - "AssistantNaiveMemory", - "AssistantZeroMemory", "AssistantMessageWindowMemory", ] -class AssistantNaiveMemory(AssistantMemory): - def __init__(self): - super().__init__(type=AssistantMemoryType.NAIVE) - - -class AssistantZeroMemory(AssistantMemory): - def __init__(self): - super().__init__(type=AssistantMemoryType.ZERO) - - class AssistantMessageWindowMemory(AssistantMemory): - def __init__(self, max_messages: int, max_tokens: int): - super().__init__(type=AssistantMemoryType.MESSAGE_WINDOW, max_messages=max_messages, max_tokens=max_tokens) + def __init__(self, max_tokens: int): + super().__init__(type=AssistantMemoryType.MESSAGE_WINDOW, max_tokens=max_tokens) diff --git a/taskingai/client/apis/__init__.py b/taskingai/client/apis/__init__.py index cf6503e..752553f 100644 --- a/taskingai/client/apis/__init__.py +++ b/taskingai/client/apis/__init__.py @@ -11,43 +11,33 @@ License: Apache 2.0 """ -from .api_bulk_create_actions import * from .api_chat_completion import * +from .api_clean_chat_context import * from .api_create_assistant import * from .api_create_chat import * -from .api_create_chunk import * from .api_create_collection import * from .api_create_message import * from .api_create_record import * -from .api_delete_action import * from .api_delete_assistant import * from .api_delete_chat import * -from .api_delete_chunk import * from .api_delete_collection import * from .api_delete_message import * from .api_delete_record import * from .api_generate_message import * -from .api_get_action import * from .api_get_assistant import * from .api_get_chat import * -from .api_get_chunk 
import * from .api_get_collection import * from .api_get_message import * from .api_get_record import * -from .api_list_actions import * from .api_list_assistants import * from .api_list_chats import * -from .api_list_chunks import * from .api_list_collections import * from .api_list_messages import * from .api_list_records import * from .api_query_collection_chunks import * -from .api_run_action import * from .api_text_embedding import * -from .api_update_action import * from .api_update_assistant import * from .api_update_chat import * -from .api_update_chunk import * from .api_update_collection import * from .api_update_message import * from .api_update_record import * diff --git a/taskingai/client/apis/api_bulk_create_actions.py b/taskingai/client/apis/api_bulk_create_actions.py deleted file mode 100644 index 88e77b6..0000000 --- a/taskingai/client/apis/api_bulk_create_actions.py +++ /dev/null @@ -1,83 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_bulk_create_actions.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ActionBulkCreateRequest, ActionBulkCreateResponse - -__all__ = ["api_bulk_create_actions", "async_api_bulk_create_actions"] - - -def api_bulk_create_actions( - payload: ActionBulkCreateRequest, - **kwargs, -) -> ActionBulkCreateResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = {} - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions/bulk_create", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionBulkCreateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_bulk_create_actions( - payload: ActionBulkCreateRequest, - **kwargs, -) -> ActionBulkCreateResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = {} - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions/bulk_create", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionBulkCreateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_delete_chunk.py b/taskingai/client/apis/api_clean_chat_context.py similarity index 68% rename from taskingai/client/apis/api_delete_chunk.py rename to taskingai/client/apis/api_clean_chat_context.py index 914b03f..736e1c9 100644 --- a/taskingai/client/apis/api_delete_chunk.py +++ b/taskingai/client/apis/api_clean_chat_context.py @@ -1,6 +1,6 @@ # -*- coding: 
utf-8 -*- -# api_delete_chunk.py +# api_create_assistant.py """ This script is automatically generated for TaskingAI python client @@ -12,23 +12,23 @@ """ from ..utils import get_api_client -from ..models import BaseEmptyResponse +from ..models import ChatCleanContextResponse -__all__ = ["api_delete_chunk", "async_api_delete_chunk"] +__all__ = ["api_clean_chat_context", "async_api_clean_chat_context"] -def api_delete_chunk( - collection_id: str, - chunk_id: str, +def api_clean_chat_context( + assistant_id: str, + chat_id: str, **kwargs, -) -> BaseEmptyResponse: +) -> ChatCleanContextResponse: # get api client sync_api_client = get_api_client(async_client=False) # request parameters path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, + "assistant_id": assistant_id, + "chat_id": chat_id, } query_params_dict = {} header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} @@ -37,15 +37,15 @@ def api_delete_chunk( # execute the request return sync_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="DELETE", + resource_path="/v1/assistants/{assistant_id}/chats/{chat_id}/clean_context", + method="POST", path_params=path_params_dict, query_params=query_params_dict, header_params=header_params_dict, body=body_params_dict, post_params=[], files=files_dict, - response_type=BaseEmptyResponse, + response_type=ChatCleanContextResponse, auth_settings=[], _return_http_data_only=True, _preload_content=True, @@ -54,18 +54,18 @@ def api_delete_chunk( ) -async def async_api_delete_chunk( - collection_id: str, - chunk_id: str, +async def async_api_clean_chat_context( + assistant_id: str, + chat_id: str, **kwargs, -) -> BaseEmptyResponse: +) -> ChatCleanContextResponse: # get api client async_api_client = get_api_client(async_client=True) # request parameters path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, + "assistant_id": assistant_id, + "chat_id": chat_id, } query_params_dict = {} header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} @@ -74,15 +74,15 @@ async def async_api_delete_chunk( # execute the request return await async_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="DELETE", + resource_path="/v1/assistants/{assistant_id}/chats/{chat_id}/clean_context", + method="POST", path_params=path_params_dict, query_params=query_params_dict, header_params=header_params_dict, body=body_params_dict, post_params=[], files=files_dict, - response_type=BaseEmptyResponse, + response_type=ChatCleanContextResponse, auth_settings=[], _return_http_data_only=True, _preload_content=True, diff --git a/taskingai/client/apis/api_create_chunk.py b/taskingai/client/apis/api_create_chunk.py deleted file mode 100644 index 606e768..0000000 --- a/taskingai/client/apis/api_create_chunk.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_create_chunk.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ChunkCreateRequest, ChunkCreateResponse - -__all__ = ["api_create_chunk", "async_api_create_chunk"] - - -def api_create_chunk( - collection_id: str, - payload: ChunkCreateRequest, - **kwargs, -) -> ChunkCreateResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request 
parameters - path_params_dict = { - "collection_id": collection_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkCreateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_create_chunk( - collection_id: str, - payload: ChunkCreateRequest, - **kwargs, -) -> ChunkCreateResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkCreateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_delete_action.py b/taskingai/client/apis/api_delete_action.py deleted file mode 100644 index 9a7b676..0000000 --- a/taskingai/client/apis/api_delete_action.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_delete_action.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import BaseEmptyResponse - -__all__ = ["api_delete_action", "async_api_delete_action"] - - -def api_delete_action( - action_id: str, - **kwargs, -) -> BaseEmptyResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="DELETE", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=BaseEmptyResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_delete_action( - action_id: str, - **kwargs, -) -> BaseEmptyResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = 
{} - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="DELETE", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=BaseEmptyResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_get_action.py b/taskingai/client/apis/api_get_action.py deleted file mode 100644 index 69566a6..0000000 --- a/taskingai/client/apis/api_get_action.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_get_action.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ActionGetResponse - -__all__ = ["api_get_action", "async_api_get_action"] - - -def api_get_action( - action_id: str, - **kwargs, -) -> ActionGetResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionGetResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_get_action( - action_id: str, - **kwargs, -) -> ActionGetResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionGetResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_get_chunk.py b/taskingai/client/apis/api_get_chunk.py deleted file mode 100644 index 69277a3..0000000 --- a/taskingai/client/apis/api_get_chunk.py +++ /dev/null @@ -1,91 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_get_chunk.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ChunkGetResponse - -__all__ = ["api_get_chunk", "async_api_get_chunk"] - - -def api_get_chunk( - collection_id: str, - chunk_id: str, - **kwargs, -) -> ChunkGetResponse: - # get api client - sync_api_client = 
get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkGetResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_get_chunk( - collection_id: str, - chunk_id: str, - **kwargs, -) -> ChunkGetResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkGetResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_list_actions.py b/taskingai/client/apis/api_list_actions.py deleted file mode 100644 index 96a10e8..0000000 --- a/taskingai/client/apis/api_list_actions.py +++ /dev/null @@ -1,83 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_list_actions.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client, convert_query_params_dict -from ..models import ActionListRequest, ActionListResponse - -__all__ = ["api_list_actions", "async_api_list_actions"] - - -def api_list_actions( - payload: ActionListRequest, - **kwargs, -) -> ActionListResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = {} - query_params_dict = convert_query_params_dict(payload.model_dump()) - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionListResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_list_actions( - payload: ActionListRequest, - **kwargs, -) -> ActionListResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = {} - query_params_dict = 
convert_query_params_dict(payload.model_dump()) - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionListResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_list_chunks.py b/taskingai/client/apis/api_list_chunks.py deleted file mode 100644 index 508136d..0000000 --- a/taskingai/client/apis/api_list_chunks.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_list_chunks.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client, convert_query_params_dict -from ..models import ChunkListRequest, ChunkListResponse - -__all__ = ["api_list_chunks", "async_api_list_chunks"] - - -def api_list_chunks( - collection_id: str, - payload: ChunkListRequest, - **kwargs, -) -> ChunkListResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - } - query_params_dict = convert_query_params_dict(payload.model_dump()) - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkListResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_list_chunks( - collection_id: str, - payload: ChunkListRequest, - **kwargs, -) -> ChunkListResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - } - query_params_dict = convert_query_params_dict(payload.model_dump()) - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = {} - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks", - method="GET", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkListResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_run_action.py b/taskingai/client/apis/api_run_action.py deleted file mode 100644 index 7c4542e..0000000 --- a/taskingai/client/apis/api_run_action.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_run_action.py - -""" -This script is automatically generated for TaskingAI 
python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ActionRunRequest, ActionRunResponse - -__all__ = ["api_run_action", "async_api_run_action"] - - -def api_run_action( - action_id: str, - payload: ActionRunRequest, - **kwargs, -) -> ActionRunResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions/{action_id}/run", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionRunResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_run_action( - action_id: str, - payload: ActionRunRequest, - **kwargs, -) -> ActionRunResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions/{action_id}/run", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionRunResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_update_action.py b/taskingai/client/apis/api_update_action.py deleted file mode 100644 index 7206239..0000000 --- a/taskingai/client/apis/api_update_action.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_update_action.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ActionUpdateRequest, ActionUpdateResponse - -__all__ = ["api_update_action", "async_api_update_action"] - - -def api_update_action( - action_id: str, - payload: ActionUpdateRequest, - **kwargs, -) -> ActionUpdateResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionUpdateResponse, - auth_settings=[], - 
_return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_update_action( - action_id: str, - payload: ActionUpdateRequest, - **kwargs, -) -> ActionUpdateResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "action_id": action_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/actions/{action_id}", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ActionUpdateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/apis/api_update_chunk.py b/taskingai/client/apis/api_update_chunk.py deleted file mode 100644 index a6e70b9..0000000 --- a/taskingai/client/apis/api_update_chunk.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- - -# api_update_chunk.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from ..utils import get_api_client -from ..models import ChunkUpdateRequest, ChunkUpdateResponse - -__all__ = ["api_update_chunk", "async_api_update_chunk"] - - -def api_update_chunk( - collection_id: str, - chunk_id: str, - payload: ChunkUpdateRequest, - **kwargs, -) -> ChunkUpdateResponse: - # get api client - sync_api_client = get_api_client(async_client=False) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, - } - query_params_dict = {} - header_params_dict = {"Accept": sync_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return sync_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkUpdateResponse, - auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) - - -async def async_api_update_chunk( - collection_id: str, - chunk_id: str, - payload: ChunkUpdateRequest, - **kwargs, -) -> ChunkUpdateResponse: - # get api client - async_api_client = get_api_client(async_client=True) - - # request parameters - path_params_dict = { - "collection_id": collection_id, - "chunk_id": chunk_id, - } - query_params_dict = {} - header_params_dict = {"Accept": async_api_client.select_header_accept(["application/json"])} - body_params_dict = payload.model_dump() - files_dict = {} - - # execute the request - return await async_api_client.call_api( - resource_path="/v1/collections/{collection_id}/chunks/{chunk_id}", - method="POST", - path_params=path_params_dict, - query_params=query_params_dict, - header_params=header_params_dict, - body=body_params_dict, - post_params=[], - files=files_dict, - response_type=ChunkUpdateResponse, - 
auth_settings=[], - _return_http_data_only=True, - _preload_content=True, - _request_timeout=kwargs.get("timeout"), - collection_formats={}, - ) diff --git a/taskingai/client/models/entities/__init__.py b/taskingai/client/models/entities/__init__.py index e2020dd..a6c6fc6 100644 --- a/taskingai/client/models/entities/__init__.py +++ b/taskingai/client/models/entities/__init__.py @@ -11,12 +11,6 @@ License: Apache 2.0 """ -from .action import * -from .action_authentication import * -from .action_authentication_type import * -from .action_body_type import * -from .action_method import * -from .action_param import * from .assistant import * from .assistant_memory import * from .assistant_memory_type import * @@ -36,8 +30,6 @@ from .chat_completion_system_message import * from .chat_completion_usage import * from .chat_completion_user_message import * -from .chat_memory import * -from .chat_memory_message import * from .chunk import * from .collection import * from .file_id_data import * diff --git a/taskingai/client/models/entities/action.py b/taskingai/client/models/entities/action.py deleted file mode 100644 index 9ec79b9..0000000 --- a/taskingai/client/models/entities/action.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -# action.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Any, Dict -from .action_method import ActionMethod -from .action_body_type import ActionBodyType -from .action_authentication import ActionAuthentication - -__all__ = ["Action"] - - -class Action(BaseModel): - object: str = Field("Action") - action_id: str = Field(..., min_length=20, max_length=30) - name: str = Field(..., min_length=1, max_length=128) - operation_id: str = Field(...) - description: str = Field(..., min_length=1, max_length=512) - url: str = Field(...) - method: ActionMethod = Field(...) - body_type: ActionBodyType = Field(...) - openapi_schema: Dict[str, Any] = Field(...) - authentication: ActionAuthentication = Field(...) - updated_timestamp: int = Field(..., ge=0) - created_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/action_authentication.py b/taskingai/client/models/entities/action_authentication.py deleted file mode 100644 index f5b0daf..0000000 --- a/taskingai/client/models/entities/action_authentication.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_authentication.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, Dict -from .action_authentication_type import ActionAuthenticationType - -__all__ = ["ActionAuthentication"] - - -class ActionAuthentication(BaseModel): - type: ActionAuthenticationType = Field(...) 
- secret: Optional[str] = Field(None, min_length=1, max_length=1024) - content: Optional[Dict] = Field(None) diff --git a/taskingai/client/models/entities/action_authentication_type.py b/taskingai/client/models/entities/action_authentication_type.py deleted file mode 100644 index 2af9d3f..0000000 --- a/taskingai/client/models/entities/action_authentication_type.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_authentication_type.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from enum import Enum - -__all__ = ["ActionAuthenticationType"] - - -class ActionAuthenticationType(str, Enum): - BEARER = "bearer" - BASIC = "basic" - CUSTOM = "custom" - NONE = "none" diff --git a/taskingai/client/models/entities/action_body_type.py b/taskingai/client/models/entities/action_body_type.py deleted file mode 100644 index 0f46243..0000000 --- a/taskingai/client/models/entities/action_body_type.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_body_type.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from enum import Enum - -__all__ = ["ActionBodyType"] - - -class ActionBodyType(str, Enum): - JSON = "JSON" - FORM = "FORM" - NONE = "NONE" diff --git a/taskingai/client/models/entities/action_method.py b/taskingai/client/models/entities/action_method.py deleted file mode 100644 index c6038b0..0000000 --- a/taskingai/client/models/entities/action_method.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_method.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from enum import Enum - -__all__ = ["ActionMethod"] - - -class ActionMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - NONE = "NONE" diff --git a/taskingai/client/models/entities/action_param.py b/taskingai/client/models/entities/action_param.py deleted file mode 100644 index 7d47e92..0000000 --- a/taskingai/client/models/entities/action_param.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_param.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, List - - -__all__ = ["ActionParam"] - - -class ActionParam(BaseModel): - type: str = Field(...) - description: str = Field(...) - enum: Optional[List[str]] = Field(None) - required: bool = Field(...) diff --git a/taskingai/client/models/entities/assistant_memory.py b/taskingai/client/models/entities/assistant_memory.py index e71f71f..8336b48 100644 --- a/taskingai/client/models/entities/assistant_memory.py +++ b/taskingai/client/models/entities/assistant_memory.py @@ -20,5 +20,4 @@ class AssistantMemory(BaseModel): type: AssistantMemoryType = Field(...) 
- max_messages: Optional[int] = Field(None, ge=1, le=1024) - max_tokens: Optional[int] = Field(None, ge=1, le=8192) + max_tokens: Optional[int] = Field(None, ge=0, le=8192) diff --git a/taskingai/client/models/entities/assistant_memory_type.py b/taskingai/client/models/entities/assistant_memory_type.py index 8ab833a..77ade7e 100644 --- a/taskingai/client/models/entities/assistant_memory_type.py +++ b/taskingai/client/models/entities/assistant_memory_type.py @@ -17,6 +17,4 @@ class AssistantMemoryType(str, Enum): - ZERO = "zero" - NAIVE = "naive" MESSAGE_WINDOW = "message_window" diff --git a/taskingai/client/models/entities/chat_memory.py b/taskingai/client/models/entities/chat_memory.py deleted file mode 100644 index 53c5248..0000000 --- a/taskingai/client/models/entities/chat_memory.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- - -# chat_memory.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, List -from .assistant_memory_type import AssistantMemoryType -from .chat_memory_message import ChatMemoryMessage - -__all__ = ["ChatMemory"] - - -class ChatMemory(BaseModel): - type: AssistantMemoryType = Field(...) - messages: List[ChatMemoryMessage] = Field([]) - max_messages: Optional[int] = Field(None, ge=1, le=1024) - max_tokens: Optional[int] = Field(None, ge=1, le=8192) diff --git a/taskingai/client/models/entities/chat_memory_message.py b/taskingai/client/models/entities/chat_memory_message.py deleted file mode 100644 index 1ad9dfc..0000000 --- a/taskingai/client/models/entities/chat_memory_message.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -# chat_memory_message.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional - - -__all__ = ["ChatMemoryMessage"] - - -class ChatMemoryMessage(BaseModel): - role: str = Field(...) - content: str = Field(...) 
- token_count: Optional[int] = Field(None) diff --git a/taskingai/client/models/entities/chunk.py b/taskingai/client/models/entities/chunk.py index 062b03d..bdfa31d 100644 --- a/taskingai/client/models/entities/chunk.py +++ b/taskingai/client/models/entities/chunk.py @@ -23,7 +23,9 @@ class Chunk(BaseModel): chunk_id: str = Field(..., min_length=20, max_length=30) record_id: Optional[str] = Field(..., min_length=20, max_length=30) collection_id: str = Field(..., min_length=20, max_length=30) - content: str = Field(..., min_length=1) + content: Optional[str] = Field(None) + question: Optional[str] = Field(None) + answer: Optional[str] = Field(None) num_tokens: int = Field(..., ge=0) metadata: Dict = Field({}, min_length=0, max_length=16) score: Optional[float] = Field(None, ge=0, le=1) diff --git a/taskingai/client/models/entities/collection.py b/taskingai/client/models/entities/collection.py index c82f215..269b1c2 100644 --- a/taskingai/client/models/entities/collection.py +++ b/taskingai/client/models/entities/collection.py @@ -11,21 +11,26 @@ License: Apache 2.0 """ +from enum import Enum from pydantic import BaseModel, Field from typing import Dict -__all__ = ["Collection"] +__all__ = ["Collection", "CollectionType"] + + +class CollectionType(str, Enum): + TEXT = "text" + QA = "qa" class Collection(BaseModel): object: str = Field("Collection") + type: CollectionType = Field(CollectionType.TEXT) collection_id: str = Field(..., min_length=24, max_length=24) name: str = Field("", min_length=0, max_length=256) description: str = Field("", min_length=0, max_length=512) capacity: int = Field(1000, ge=1) - num_records: int = Field(..., ge=0) - num_chunks: int = Field(..., ge=0) embedding_model_id: str = Field(..., min_length=8, max_length=8) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) updated_timestamp: int = Field(..., ge=0) diff --git a/taskingai/client/models/entities/record_type.py b/taskingai/client/models/entities/record_type.py index 9235603..c6218d9 100644 --- a/taskingai/client/models/entities/record_type.py +++ b/taskingai/client/models/entities/record_type.py @@ -20,3 +20,4 @@ class RecordType(str, Enum): TEXT = "text" FILE = "file" WEB = "web" + QA_SHEET = "qa_sheet" diff --git a/taskingai/client/models/entities/upload_file_purpose.py b/taskingai/client/models/entities/upload_file_purpose.py index 1296b78..0e0a895 100644 --- a/taskingai/client/models/entities/upload_file_purpose.py +++ b/taskingai/client/models/entities/upload_file_purpose.py @@ -18,3 +18,4 @@ class UploadFilePurpose(str, Enum): RECORD_FILE = "record_file" + QA_RECORD_FILE = "qa_record_file" diff --git a/taskingai/client/models/schemas/__init__.py b/taskingai/client/models/schemas/__init__.py index 5300362..c3de815 100644 --- a/taskingai/client/models/schemas/__init__.py +++ b/taskingai/client/models/schemas/__init__.py @@ -11,15 +11,6 @@ License: Apache 2.0 """ -from .action_bulk_create_request import * -from .action_bulk_create_response import * -from .action_get_response import * -from .action_list_request import * -from .action_list_response import * -from .action_run_request import * -from .action_run_response import * -from .action_update_request import * -from .action_update_response import * from .assistant_create_request import * from .assistant_create_response import * from .assistant_get_response import * @@ -29,6 +20,7 @@ from .assistant_update_response import * from .base_data_response import * from .base_empty_response import * +from .chat_clean_context_response import * from 
.chat_completion_request import * from .chat_completion_response import * from .chat_create_request import * @@ -38,15 +30,8 @@ from .chat_list_response import * from .chat_update_request import * from .chat_update_response import * -from .chunk_create_request import * -from .chunk_create_response import * -from .chunk_get_response import * -from .chunk_list_request import * -from .chunk_list_response import * from .chunk_query_request import * from .chunk_query_response import * -from .chunk_update_request import * -from .chunk_update_response import * from .collection_create_request import * from .collection_create_response import * from .collection_get_response import * diff --git a/taskingai/client/models/schemas/action_bulk_create_request.py b/taskingai/client/models/schemas/action_bulk_create_request.py deleted file mode 100644 index 0c3c971..0000000 --- a/taskingai/client/models/schemas/action_bulk_create_request.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_bulk_create_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Dict -from ..entities.action_authentication import ActionAuthentication - -__all__ = ["ActionBulkCreateRequest"] - - -class ActionBulkCreateRequest(BaseModel): - openapi_schema: Dict = Field(...) - authentication: ActionAuthentication = Field(...) diff --git a/taskingai/client/models/schemas/action_bulk_create_response.py b/taskingai/client/models/schemas/action_bulk_create_response.py deleted file mode 100644 index 36b3681..0000000 --- a/taskingai/client/models/schemas/action_bulk_create_response.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_bulk_create_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import List -from ..entities.action import Action - -__all__ = ["ActionBulkCreateResponse"] - - -class ActionBulkCreateResponse(BaseModel): - status: str = Field("success") - data: List[Action] = Field(...) diff --git a/taskingai/client/models/schemas/action_get_response.py b/taskingai/client/models/schemas/action_get_response.py deleted file mode 100644 index e9c6159..0000000 --- a/taskingai/client/models/schemas/action_get_response.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_get_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from ..entities.action import Action - -__all__ = ["ActionGetResponse"] - - -class ActionGetResponse(BaseModel): - status: str = Field("success") - data: Action = Field(...) 
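With the `Action` entity and its create/get schemas deleted above, assistants no longer attach `ToolType.ACTION` tools; the updated tests later in this diff reference built-in plugins instead. Below is a minimal sketch of the new-style tool reference, using the plugin id that appears in those tests; the import path mirrors the test suite's `from taskingai.client.models import ToolRef, ToolType`.

```python
from taskingai.client.models import ToolRef, ToolType

# A plugin reference replaces the old ToolType.ACTION + action_id pair.
# "open_weather/get_hourly_forecast" is the plugin id used in the updated tests.
weather_tool = ToolRef(
    type=ToolType.PLUGIN,
    id="open_weather/get_hourly_forecast",
)
```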
diff --git a/taskingai/client/models/schemas/action_list_request.py b/taskingai/client/models/schemas/action_list_request.py deleted file mode 100644 index 7aecf68..0000000 --- a/taskingai/client/models/schemas/action_list_request.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_list_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional -from ..entities.sort_order_enum import SortOrderEnum - -__all__ = ["ActionListRequest"] - - -class ActionListRequest(BaseModel): - limit: int = Field(20) - order: Optional[SortOrderEnum] = Field(None) - after: Optional[str] = Field(None) - before: Optional[str] = Field(None) diff --git a/taskingai/client/models/schemas/action_list_response.py b/taskingai/client/models/schemas/action_list_response.py deleted file mode 100644 index 1123acd..0000000 --- a/taskingai/client/models/schemas/action_list_response.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_list_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import List -from ..entities.action import Action - -__all__ = ["ActionListResponse"] - - -class ActionListResponse(BaseModel): - status: str = Field("success") - data: List[Action] = Field(...) - fetched_count: int = Field(...) - has_more: bool = Field(...) diff --git a/taskingai/client/models/schemas/action_run_request.py b/taskingai/client/models/schemas/action_run_request.py deleted file mode 100644 index f1466c9..0000000 --- a/taskingai/client/models/schemas/action_run_request.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_run_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, Any, Dict - - -__all__ = ["ActionRunRequest"] - - -class ActionRunRequest(BaseModel): - parameters: Optional[Dict[str, Any]] = Field(None) diff --git a/taskingai/client/models/schemas/action_run_response.py b/taskingai/client/models/schemas/action_run_response.py deleted file mode 100644 index 605dc47..0000000 --- a/taskingai/client/models/schemas/action_run_response.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_run_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Dict - - -__all__ = ["ActionRunResponse"] - - -class ActionRunResponse(BaseModel): - status: str = Field("success") - data: Dict = Field(...) 
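The run schemas deleted above mean there is no longer a standalone `run_action` call, and this diff shows no direct replacement: tools are now exercised inside a chat turn. A hedged sketch of that flow follows, assuming the async helpers used by the updated tests (`a_create_chat`, `a_create_message`, `a_generate_message`) are importable from `taskingai.assistant`, as the test suite's wildcard import suggests.

```python
from taskingai.assistant import a_create_chat, a_create_message, a_generate_message


async def ask_with_tools(assistant_id: str) -> None:
    # Instead of running an action directly, post a user message and let the
    # assistant invoke its plugin tools while generating the reply.
    chat = await a_create_chat(assistant_id=assistant_id, name="demo_chat")
    await a_create_message(
        assistant_id=assistant_id,
        chat_id=chat.chat_id,
        text="hello, what is the weather like in HongKong?",
    )
    message = await a_generate_message(
        assistant_id=assistant_id,
        chat_id=chat.chat_id,
        system_prompt_variables={},
    )
    print(message.content)

# asyncio.run(ask_with_tools("ASSISTANT_ID"))  # placeholder id; needs a configured client
```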
diff --git a/taskingai/client/models/schemas/action_update_request.py b/taskingai/client/models/schemas/action_update_request.py deleted file mode 100644 index 0774bf8..0000000 --- a/taskingai/client/models/schemas/action_update_request.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# action_update_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, Any, Dict -from ..entities.action_authentication import ActionAuthentication - -__all__ = ["ActionUpdateRequest"] - - -class ActionUpdateRequest(BaseModel): - openapi_schema: Optional[Dict[str, Any]] = Field(None) - authentication: Optional[ActionAuthentication] = Field(None) diff --git a/taskingai/client/models/schemas/action_update_response.py b/taskingai/client/models/schemas/chat_clean_context_response.py similarity index 60% rename from taskingai/client/models/schemas/action_update_response.py rename to taskingai/client/models/schemas/chat_clean_context_response.py index 54a56b7..01d7ebf 100644 --- a/taskingai/client/models/schemas/action_update_response.py +++ b/taskingai/client/models/schemas/chat_clean_context_response.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# action_update_response.py +# chat_completion_response.py """ This script is automatically generated for TaskingAI python client @@ -12,11 +12,11 @@ """ from pydantic import BaseModel, Field -from ..entities.action import Action +from ..entities.message import Message -__all__ = ["ActionUpdateResponse"] +__all__ = ["ChatCleanContextResponse"] -class ActionUpdateResponse(BaseModel): +class ChatCleanContextResponse(BaseModel): status: str = Field("success") - data: Action = Field(...) + data: Message = Field(...) diff --git a/taskingai/client/models/schemas/chunk_create_request.py b/taskingai/client/models/schemas/chunk_create_request.py deleted file mode 100644 index e599c62..0000000 --- a/taskingai/client/models/schemas/chunk_create_request.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_create_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Dict - - -__all__ = ["ChunkCreateRequest"] - - -class ChunkCreateRequest(BaseModel): - content: str = Field(...) - metadata: Dict = Field({}) diff --git a/taskingai/client/models/schemas/chunk_create_response.py b/taskingai/client/models/schemas/chunk_create_response.py deleted file mode 100644 index cba6ac8..0000000 --- a/taskingai/client/models/schemas/chunk_create_response.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_create_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from ..entities.chunk import Chunk - -__all__ = ["ChunkCreateResponse"] - - -class ChunkCreateResponse(BaseModel): - status: str = Field("success") - data: Chunk = Field(...) 
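Among the changes above, the old `action_update_response.py` is renamed into `chat_clean_context_response.py`, the schema behind a new clean-context operation. The updated async tests call `a_clean_chat_context` and expect a system message whose content text is `"context_cleaned"`; here is a small sketch under those assumptions.

```python
from taskingai.assistant import a_clean_chat_context  # name taken from the updated tests


async def reset_context(assistant_id: str, chat_id: str) -> None:
    # Clears the chat's accumulated context; per the tests, the server replies
    # with a system message marking the cut-off point.
    message = await a_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id)
    assert message.role == "system"
    assert message.content.text == "context_cleaned"
```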
diff --git a/taskingai/client/models/schemas/chunk_get_response.py b/taskingai/client/models/schemas/chunk_get_response.py deleted file mode 100644 index 3d1afeb..0000000 --- a/taskingai/client/models/schemas/chunk_get_response.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_get_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from ..entities.chunk import Chunk - -__all__ = ["ChunkGetResponse"] - - -class ChunkGetResponse(BaseModel): - status: str = Field("success") - data: Chunk = Field(...) diff --git a/taskingai/client/models/schemas/chunk_list_request.py b/taskingai/client/models/schemas/chunk_list_request.py deleted file mode 100644 index aa2348f..0000000 --- a/taskingai/client/models/schemas/chunk_list_request.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_list_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional -from ..entities.sort_order_enum import SortOrderEnum - -__all__ = ["ChunkListRequest"] - - -class ChunkListRequest(BaseModel): - limit: int = Field(20) - order: Optional[SortOrderEnum] = Field(None) - after: Optional[str] = Field(None) - before: Optional[str] = Field(None) diff --git a/taskingai/client/models/schemas/chunk_list_response.py b/taskingai/client/models/schemas/chunk_list_response.py deleted file mode 100644 index c0fd794..0000000 --- a/taskingai/client/models/schemas/chunk_list_response.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_list_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import List -from ..entities.chunk import Chunk - -__all__ = ["ChunkListResponse"] - - -class ChunkListResponse(BaseModel): - status: str = Field("success") - data: List[Chunk] = Field(...) - fetched_count: int = Field(...) - has_more: bool = Field(...) 
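Alongside the removal of the per-chunk CRUD schemas around this point, the retrieval changes in this diff (the `CollectionType` enum, the `qa_sheet` record type, the `qa_record_file` upload purpose, and the now-optional `text_splitter`) add a question-answer flavored pipeline, with querying the only operation left in `taskingai/retrieval/chunk.py`. A hedged end-to-end sketch: the import paths are assumed package re-exports, the ids are placeholders, and `score_threshold` is taken from the updated `query_chunks` docstring.

```python
from taskingai.retrieval import CollectionType, create_collection, create_record, query_chunks

# 1. Collections are now typed; CollectionType.QA complements the default TEXT.
collection = create_collection(
    embedding_model_id="XXXXXXXX",  # placeholder 8-character embedding model id
    type=CollectionType.QA,
    capacity=1000,
)

# 2. QA sheets such as test/qa_files/Q&A.csv become "qa_sheet" records. With
#    text_splitter now optional, none is passed here; the file_id would come
#    from an upload with purpose "qa_record_file".
record = create_record(
    collection_id=collection.collection_id,
    type="qa_sheet",
    file_id="FILE_ID",  # placeholder
)

# 3. Chunks of a QA collection carry question/answer instead of content, and
#    query_chunks now documents a minimum-score filter.
chunks = query_chunks(
    collection_id=collection.collection_id,
    query_text="What is NBA?",
    top_k=3,
    score_threshold=0.5,
)
for chunk in chunks:
    print(chunk.score, chunk.question, chunk.answer)
```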
diff --git a/taskingai/client/models/schemas/chunk_update_request.py b/taskingai/client/models/schemas/chunk_update_request.py deleted file mode 100644 index 44bfa29..0000000 --- a/taskingai/client/models/schemas/chunk_update_request.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_update_request.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from typing import Optional, Dict - - -__all__ = ["ChunkUpdateRequest"] - - -class ChunkUpdateRequest(BaseModel): - content: Optional[str] = Field(None) - metadata: Optional[Dict] = Field(None) diff --git a/taskingai/client/models/schemas/chunk_update_response.py b/taskingai/client/models/schemas/chunk_update_response.py deleted file mode 100644 index 5cc6db8..0000000 --- a/taskingai/client/models/schemas/chunk_update_response.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -# chunk_update_response.py - -""" -This script is automatically generated for TaskingAI python client -Do not modify the file manually - -Author: James Yao -Organization: TaskingAI -License: Apache 2.0 -""" - -from pydantic import BaseModel, Field -from ..entities.chunk import Chunk - -__all__ = ["ChunkUpdateResponse"] - - -class ChunkUpdateResponse(BaseModel): - status: str = Field("success") - data: Chunk = Field(...) diff --git a/taskingai/client/models/schemas/collection_create_request.py b/taskingai/client/models/schemas/collection_create_request.py index 285756a..262a2e2 100644 --- a/taskingai/client/models/schemas/collection_create_request.py +++ b/taskingai/client/models/schemas/collection_create_request.py @@ -14,12 +14,15 @@ from pydantic import BaseModel, Field from typing import Dict +from ..entities.collection import CollectionType + __all__ = ["CollectionCreateRequest"] class CollectionCreateRequest(BaseModel): name: str = Field("") + type: CollectionType = Field(CollectionType.TEXT) description: str = Field("") capacity: int = Field(1000) embedding_model_id: str = Field(...) diff --git a/taskingai/client/models/schemas/record_create_request.py b/taskingai/client/models/schemas/record_create_request.py index 5e5d5fb..4a358e7 100644 --- a/taskingai/client/models/schemas/record_create_request.py +++ b/taskingai/client/models/schemas/record_create_request.py @@ -25,5 +25,5 @@ class RecordCreateRequest(BaseModel): url: Optional[str] = Field(None, min_length=1, max_length=2048) title: str = Field("", min_length=0, max_length=256) content: Optional[str] = Field(None, min_length=1, max_length=32768) - text_splitter: TextSplitter = Field(...) 
+ text_splitter: Optional[TextSplitter] = Field(None) metadata: Dict[str, str] = Field({}, min_length=0, max_length=16) diff --git a/taskingai/client/rest.py b/taskingai/client/rest.py index a0fbf06..176fd53 100644 --- a/taskingai/client/rest.py +++ b/taskingai/client/rest.py @@ -139,11 +139,11 @@ def request( elif headers["Content-Type"] == "multipart/form-data": # In the case of multipart, we leave it to httpx to encode the files and data request_content = post_params - elif body: + elif body is not None: if "Content-Type" not in headers: headers["Content-Type"] = "application/json" - if body is not None: - request_content = json.dumps(body) + # if body is not None: + request_content = json.dumps(body) elif files: request_files = files @@ -395,11 +395,11 @@ async def request( elif headers["Content-Type"] == "multipart/form-data": # In the case of multipart, we leave it to httpx to encode the files and data request_content = post_params - elif body: + elif body is not None: if "Content-Type" not in headers: headers["Content-Type"] = "application/json" - if body is not None: - request_content = json.dumps(body) + # if body is not None: + request_content = json.dumps(body) elif files: request_files = files diff --git a/taskingai/retrieval/chunk.py b/taskingai/retrieval/chunk.py index 6834d19..3b8749d 100644 --- a/taskingai/retrieval/chunk.py +++ b/taskingai/retrieval/chunk.py @@ -1,248 +1,15 @@ -from typing import List, Optional, Dict +from typing import List, Optional from taskingai.client.models import * from taskingai.client.apis import * __all__ = [ "Chunk", - "list_chunks", - "a_list_chunks", - "get_chunk", - "a_get_chunk", - "create_chunk", - "a_create_chunk", - "update_chunk", - "a_update_chunk", - "delete_chunk", - "a_delete_chunk", "query_chunks", "a_query_chunks", ] -def list_chunks( - collection_id: str, - *, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, -) -> List[Chunk]: - """ - List chunks. - - :param collection_id: The ID of the collection. - :param order: The order of the chunks. It can be "asc" or "desc". - :param limit: The maximum number of assistants to return. - :param after: The cursor to get the next page of chunks. - :param before: The cursor to get the previous page of chunks. - :return: The list of chunks. - """ - - if after and before: - raise ValueError("Only one of after and before can be specified.") - - # only add non-None parameters - payload = ChunkListRequest( - order=order, - limit=limit, - after=after, - before=before, - ) - response: ChunkListResponse = api_list_chunks(collection_id=collection_id, payload=payload) - return response.data - - -async def a_list_chunks( - collection_id: str, - *, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, -) -> List[Chunk]: - """ - List chunks in async mode. - :param collection_id: The ID of the collection. - :param order: The order of the chunks. It can be "asc" or "desc". - :param limit: The maximum number of chunks to return. - :param after: The cursor to get the next page of chunks. - :param before: The cursor to get the previous page of chunks. - :return: The list of chunks. 
- """ - - if after and before: - raise ValueError("Only one of after and before can be specified.") - - # only add non-None parameters - payload = ChunkListRequest( - order=order, - limit=limit, - after=after, - before=before, - ) - response: ChunkListResponse = await async_api_list_chunks(collection_id=collection_id, payload=payload) - return response.data - - -def get_chunk(collection_id: str, chunk_id: str) -> Chunk: - """ - Get a chunk. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - """ - - response: ChunkGetResponse = api_get_chunk( - collection_id=collection_id, - chunk_id=chunk_id, - ) - return response.data - - -async def a_get_chunk(collection_id: str, chunk_id: str) -> Chunk: - """ - Get a chunk in async mode. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - """ - - response: ChunkGetResponse = await async_api_get_chunk( - collection_id=collection_id, - chunk_id=chunk_id, - ) - return response.data - - -def create_chunk( - collection_id: str, - *, - content: str, - metadata: Optional[Dict[str, str]] = None, -) -> Chunk: - """ - Create a chunk. - - :param collection_id: The ID of the collection. - :param content: The content of the chunk. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. - :return: The created chunk object. - """ - - body = ChunkCreateRequest( - content=content, - metadata=metadata or {}, - ) - response: ChunkCreateResponse = api_create_chunk(collection_id=collection_id, payload=body) - return response.data - - -async def a_create_chunk( - collection_id: str, - *, - content: str, - metadata: Optional[Dict[str, str]] = None, -) -> Chunk: - """ - Create a chunk in async mode. - - :param collection_id: The ID of the collection. - :param content: The content of the chunk. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. - :return: The created chunk object. - """ - - body = ChunkCreateRequest( - content=content, - metadata=metadata or {}, - ) - response: ChunkCreateResponse = await async_api_create_chunk(collection_id=collection_id, payload=body) - return response.data - - -def update_chunk( - collection_id: str, - chunk_id: str, - *, - content: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, -) -> Chunk: - """ - Update a chunk. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - :param content: The content of the chunk. - :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less - than 64 and value's length is less than 512. - :return: The collection object. - """ - - body = ChunkUpdateRequest( - content=content, - metadata=metadata, - ) - response: ChunkUpdateResponse = api_update_chunk(collection_id=collection_id, chunk_id=chunk_id, payload=body) - return response.data - - -async def a_update_chunk( - collection_id: str, - chunk_id: str, - *, - content: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, -) -> Chunk: - """ - Update a chunk in async mode. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - :param content: The content of the chunk. - :param metadata: The collection metadata. 
It can store up to 16 key-value pairs where each key's length is less - than 64 and value's length is less than 512. - :return: The collection object. - """ - - body = ChunkUpdateRequest( - content=content, - metadata=metadata, - ) - response: ChunkUpdateResponse = await async_api_update_chunk( - collection_id=collection_id, chunk_id=chunk_id, payload=body - ) - return response.data - - -def delete_chunk( - collection_id: str, - chunk_id: str, -) -> None: - """ - Delete a chunk. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - """ - - api_delete_chunk(collection_id=collection_id, chunk_id=chunk_id) - - -async def a_delete_chunk( - collection_id: str, - chunk_id: str, -) -> None: - """ - Delete a chunk in async mode. - - :param collection_id: The ID of the collection. - :param chunk_id: The ID of the chunk. - """ - - await async_api_delete_chunk(collection_id=collection_id, chunk_id=chunk_id) - - def query_chunks( collection_id: str, *, @@ -256,6 +23,7 @@ def query_chunks( :param collection_id: The ID of the collection. :param query_text: The query text. :param top_k: The number of most relevant chunks to return. + :param score_threshold: The minimum score threshold to return. :param max_tokens: The maximum number of tokens to return. """ @@ -286,6 +54,7 @@ async def a_query_chunks( :param collection_id: The ID of the collection. :param query_text: The query text. :param top_k: The number of most relevant chunks to return. + :param score_threshold: The minimum score threshold to return. :param max_tokens: The maximum number of tokens to return. """ diff --git a/taskingai/retrieval/collection.py b/taskingai/retrieval/collection.py index 3f88693..6eb723f 100644 --- a/taskingai/retrieval/collection.py +++ b/taskingai/retrieval/collection.py @@ -107,6 +107,7 @@ def create_collection( embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, + type: Optional[CollectionType] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Collection: @@ -116,6 +117,7 @@ def create_collection( :param embedding_model_id: The ID of an available embedding model in the project. :param capacity: The maximum number of embeddings that can be stored in the collection. :param name: The name of the collection. + :param type: The type of the collection. :param description: The description of the collection. :param metadata: The collection metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created collection object. @@ -125,6 +127,7 @@ def create_collection( embedding_model_id=embedding_model_id, capacity=capacity, name=name or "", + type=type or CollectionType.TEXT, description=description or "", metadata=metadata or {}, ) @@ -137,6 +140,7 @@ async def a_create_collection( embedding_model_id: str, capacity: int = 1000, name: Optional[str] = None, + type: Optional[CollectionType] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, ) -> Collection: @@ -146,16 +150,17 @@ async def a_create_collection( :param embedding_model_id: The ID of an available embedding model in the project. :param capacity: The maximum number of embeddings that can be stored in the collection. :param name: The name of the collection. + :param type: The type of the collection. :param description: The description of the collection. :param metadata: The collection metadata. 
It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created collection object. """ - # todo verify parameters body = CollectionCreateRequest( embedding_model_id=embedding_model_id, capacity=capacity, name=name or "", + type=type or CollectionType.TEXT, description=description or "", metadata=metadata or {}, ) diff --git a/taskingai/retrieval/record.py b/taskingai/retrieval/record.py index 094f71e..d5fde94 100644 --- a/taskingai/retrieval/record.py +++ b/taskingai/retrieval/record.py @@ -137,7 +137,7 @@ def create_record( collection_id: str, *, type: Union[RecordType, str], - text_splitter: Union[TextSplitter, Dict[str, Any]], + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, title: Optional[str] = None, content: Optional[str] = None, file_id: Optional[str] = None, @@ -158,7 +158,8 @@ def create_record( :return: The created record object. """ type = _validate_record_type(type, content, file_id, url) - text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) body = RecordCreateRequest( title=title or "", @@ -177,7 +178,7 @@ async def a_create_record( collection_id: str, *, type: Union[RecordType, str], - text_splitter: Union[TextSplitter, Dict[str, Any]], + text_splitter: Optional[Union[TextSplitter, Dict[str, Any]]] = None, title: Optional[str] = None, content: Optional[str] = None, file_id: Optional[str] = None, @@ -199,7 +200,8 @@ async def a_create_record( """ type = _validate_record_type(type, content, file_id, url) - text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) + if text_splitter: + text_splitter = text_splitter if isinstance(text_splitter, TextSplitter) else TextSplitter(**text_splitter) body = RecordCreateRequest( title=title or "", diff --git a/taskingai/tool/__init__.py b/taskingai/tool/__init__.py deleted file mode 100644 index 9625676..0000000 --- a/taskingai/tool/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .action import * diff --git a/taskingai/tool/action.py b/taskingai/tool/action.py deleted file mode 100644 index 83b4384..0000000 --- a/taskingai/tool/action.py +++ /dev/null @@ -1,277 +0,0 @@ -from typing import Any, Optional, List, Dict, Union - -from taskingai.client.models import * -from taskingai.client.apis import * - - -__all__ = [ - "Action", - "ActionAuthentication", - "ActionAuthenticationType", - "get_action", - "list_actions", - "bulk_create_actions", - "update_action", - "delete_action", - "run_action", - "a_get_action", - "a_list_actions", - "a_bulk_create_actions", - "a_update_action", - "a_delete_action", - "a_run_action", -] - - -def list_actions( - *, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, -) -> List[Action]: - """ - List actions. - - :param order: The order of the actions. It can be "asc" or "desc". - :param limit: The maximum number of actions to return. - :param after: The cursor to get the next page of actions. - :param before: The cursor to get the previous page of actions. - :return: The list of actions. 
- """ - if after and before: - raise ValueError("Only one of after and before can be specified.") - - # only add non-None parameters - payload = ActionListRequest( - order=order, - limit=limit, - after=after, - before=before, - ) - response: ActionListResponse = api_list_actions(payload=payload) - return response.data - - -async def a_list_actions( - *, - order: str = "desc", - limit: int = 20, - after: Optional[str] = None, - before: Optional[str] = None, -) -> List[Action]: - """ - List actions in async mode. - - :param order: The order of the actions. It can be "asc" or "desc". - :param limit: The maximum number of actions to return. - :param after: The cursor to get the next page of actions. - :param before: The cursor to get the previous page of actions. - :return: The list of actions. - """ - if after and before: - raise ValueError("Only one of after and before can be specified.") - - # only add non-None parameters - payload = ActionListRequest( - order=order, - limit=limit, - after=after, - before=before, - ) - response: ActionListResponse = await async_api_list_actions(payload) - return response.data - - -def get_action(action_id: str) -> Action: - """ - Get an action. - - :param action_id: The ID of the action. - """ - - response: ActionGetResponse = api_get_action(action_id=action_id) - return response.data - - -async def a_get_action(action_id: str) -> Action: - """ - Get an action in async mode. - - :param action_id: The ID of the action. - """ - - response: ActionGetResponse = await async_api_get_action(action_id=action_id) - return response.data - - -def bulk_create_actions( - *, - openapi_schema: Dict, - authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, -) -> List[Action]: - """ - Create actions from an OpenAPI schema. - - :param openapi_schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method - :param authentication: The action API authentication. - :return: The created action object. - """ - - authentication = ( - authentication - if isinstance(authentication, ActionAuthentication) - else ActionAuthentication(**(authentication or ActionAuthentication(type=ActionAuthenticationType.NONE))) - ) - - body = ActionBulkCreateRequest( - openapi_schema=openapi_schema, - authentication=authentication, - ) - response: ActionBulkCreateResponse = api_bulk_create_actions(payload=body) - return response.data - - -async def a_bulk_create_actions( - *, - openapi_schema: Dict, - authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, -) -> List[Action]: - """ - Create actions from an OpenAPI schema in async mode. - - :param openapi_schema: The action schema is compliant with the OpenAPI Specification. If there are multiple paths and methods in the openapi_schema, the service will create multiple actions whose openapi_schema only has exactly one path and one method - :param authentication: The action API authentication. - :return: The created action object. 
- """ - - authentication = ( - authentication - if isinstance(authentication, ActionAuthentication) - else ActionAuthentication(**(authentication or ActionAuthentication(type=ActionAuthenticationType.NONE))) - ) - - body = ActionBulkCreateRequest( - openapi_schema=openapi_schema, - authentication=authentication, - ) - response: ActionBulkCreateResponse = await async_api_bulk_create_actions(payload=body) - return response.data - - -def update_action( - action_id: str, - *, - openapi_schema: Optional[Dict] = None, - authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, -) -> Action: - """ - Update an action. - - :param action_id: The ID of the action. - :param openapi_schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. - :param authentication: The action API authentication. - :return: The updated action object. - """ - if authentication: - authentication = ( - authentication - if isinstance(authentication, ActionAuthentication) - else ActionAuthentication(**authentication) - ) - body = ActionUpdateRequest( - openapi_schema=openapi_schema, - authentication=authentication, - ) - response: ActionUpdateResponse = api_update_action(action_id=action_id, payload=body) - return response.data - - -async def a_update_action( - action_id: str, - *, - openapi_schema: Optional[Dict] = None, - authentication: Optional[Union[ActionAuthentication, Dict[str, Any]]] = None, -) -> Action: - """ - Update an action in async mode. - - :param action_id: The ID of the action. - :param openapi_schema: The action schema, which is compliant with the OpenAPI Specification. It should only have exactly one path and one method. - :param authentication: The action API authentication. - :return: The updated action object. - """ - if authentication: - authentication = ( - authentication - if isinstance(authentication, ActionAuthentication) - else ActionAuthentication(**authentication) - ) - body = ActionUpdateRequest( - openapi_schema=openapi_schema, - authentication=authentication, - ) - response: ActionUpdateResponse = await async_api_update_action(action_id=action_id, payload=body) - return response.data - - -def delete_action(action_id: str) -> None: - """ - Delete an action. - - :param action_id: The ID of the action. - """ - - api_delete_action(action_id=action_id) - - -async def a_delete_action(action_id: str) -> None: - """ - Delete an action in async mode. - - :param action_id: The ID of the action. - """ - - await async_api_delete_action(action_id=action_id) - - -def run_action( - action_id: str, - *, - parameters: Dict, -) -> Dict: - """ - Run an action. - - :param action_id: The ID of the action. - :param parameters: The action parameters. - :return: The action response. - """ - - body = ActionRunRequest( - parameters=parameters, - ) - response: ActionRunResponse = api_run_action(action_id=action_id, payload=body) - result = response.data - return result - - -async def a_run_action( - action_id: str, - *, - parameters: Dict, -) -> Dict: - """ - Run an action in async mode. - - :param action_id: The ID of the action. - :param parameters: The action parameters. - :return: The action response. 
- """ - - body = ActionRunRequest( - parameters=parameters, - ) - response: ActionRunResponse = await async_api_run_action(action_id=action_id, payload=body) - result = response.data - return result diff --git a/test/common/utils.py b/test/common/utils.py index dacb9bf..ec0bfc9 100644 --- a/test/common/utils.py +++ b/test/common/utils.py @@ -118,7 +118,7 @@ def assume_record_result(create_record_data: dict, res_dict: dict): else: pytest.assume(res_dict[key] == create_record_data[key]) - pytest.assume(res_dict["status"] == "ready") + pytest.assume(res_dict["status"] == "creating") def assume_chunk_result(chunk_dict: dict, res: dict): diff --git a/test/config.py b/test/config.py index c2271fa..f4da7bc 100644 --- a/test/config.py +++ b/test/config.py @@ -29,4 +29,4 @@ class Config: taskingai.init(api_key=taskingai_apikey, host=taskingai_host) - sleep_time = 1 + sleep_time = 4 diff --git a/test/qa_files/Q&A.csv b/test/qa_files/Q&A.csv new file mode 100644 index 0000000..0bf0956 --- /dev/null +++ b/test/qa_files/Q&A.csv @@ -0,0 +1,6 @@ +question,answer +What is NBA?,NBA is baskball sport. +Do you know how to test,"Yes, test is a very import thing" +"hello, who are you",i am an assitant +how to be success,just do it +what is AI,AI is a tool to help people diff --git a/test/qa_files/Q&A.xlsx b/test/qa_files/Q&A.xlsx new file mode 100644 index 0000000..866cb14 Binary files /dev/null and b/test/qa_files/Q&A.xlsx differ diff --git a/test/testcase/test_async/__init__.py b/test/testcase/test_async/__init__.py index 1f4fd37..06bcc80 100644 --- a/test/testcase/test_async/__init__.py +++ b/test/testcase/test_async/__init__.py @@ -1,3 +1,4 @@ class Base: - - collection_id, record_id, chunk_id, action_id, assistant_id, chat_id, message_id = None, None, None, None, None, None, None \ No newline at end of file + collection_id, record_id, chunk_id, assistant_id, chat_id, message_id = None, None, None, None, None, None + qa_collection_id, qa_record_id = None, None + # action_id = "cuPLXedFSRbWGK0VfNI4VZBc" diff --git a/test/testcase/test_async/conftest.py b/test/testcase/test_async/conftest.py index e4e7f36..1c8d0a5 100644 --- a/test/testcase/test_async/conftest.py +++ b/test/testcase/test_async/conftest.py @@ -1,6 +1,5 @@ from taskingai.assistant import a_list_assistants, a_list_chats, a_list_messages from taskingai.retrieval import a_list_collections, a_list_records -from taskingai.tool import a_list_actions import pytest import asyncio @@ -44,13 +43,6 @@ async def a_record_id(a_collection_id): return record_id -@pytest.fixture(scope="function") -async def a_action_id(): - res = await a_list_actions() - action_id = res[-1].action_id - return action_id - - @pytest.fixture(scope="session") def event_loop(request): """Create an instance of the default event loop for each test case.""" diff --git a/test/testcase/test_async/test_async_assistant.py b/test/testcase/test_async/test_async_assistant.py index 7237b1b..4330dfa 100644 --- a/test/testcase/test_async/test_async_assistant.py +++ b/test/testcase/test_async/test_async_assistant.py @@ -2,30 +2,29 @@ from taskingai.assistant import * from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig -from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory +from taskingai.assistant.memory import AssistantMessageWindowMemory from test.config import Config from test.common.logger import logger -from test.common.utils import list_to_dict -from test.common.utils import assume_assistant_result, 
assume_chat_result, assume_message_result +from test.common.utils import assume_assistant_result from test.testcase.test_async import Base @pytest.mark.test_async class TestAssistant(Base): - @pytest.mark.run(order=51) @pytest.mark.asyncio async def test_a_create_assistant(self): - # Create an assistant. assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( @@ -33,41 +32,31 @@ async def test_a_create_assistant(self): id=self.collection_id, ), ], - "retrieval_configs": RetrievalConfig( - method="memory", - top_k=1, - max_tokens=5000, - score_threshold=0.5 - - ), + "retrieval_configs": RetrievalConfig(method="memory", top_k=1, max_tokens=5000, score_threshold=0.5), "tools": [ - ToolRef( - type=ToolType.ACTION, - id=self.action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], } for i in range(4): if i == 0: - assistant_dict.update({"memory": {"type": "naive"}}) + assistant_dict.update({"memory": {"type": "message_window", "max_messages": 50, "max_tokens": 2000}}) assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]}) - assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}}) - assistant_dict.update({"tools": [{"type": "action", "id": self.action_id}, - {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) + assistant_dict.update( + {"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}} + ) + assistant_dict.update({"tools": [{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) res = await a_create_assistant(**assistant_dict) res_dict = vars(res) - logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, res_dict) Base.assistant_id = res.assistant_id @pytest.mark.run(order=52) @pytest.mark.asyncio async def test_a_list_assistants(self): - # List assistants. nums_limit = 1 @@ -92,7 +81,6 @@ async def test_a_list_assistants(self): @pytest.mark.run(order=53) @pytest.mark.asyncio async def test_a_get_assistant(self): - # Get an assistant. res = await a_get_assistant(assistant_id=self.assistant_id) @@ -102,59 +90,45 @@ async def test_a_get_assistant(self): @pytest.mark.run(order=54) @pytest.mark.asyncio async def test_a_update_assistant(self): - # Update an assistant. 
update_data_list = [ { "name": "openai", "description": "test for openai", - "memory": AssistantZeroMemory(), + "memory": AssistantMessageWindowMemory(max_tokens=2000), "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, id=self.collection_id, ), ], - "retrieval_configs": RetrievalConfig( - method="memory", - top_k=2, - max_tokens=4000, - score_threshold=0.5 - - ), + "retrieval_configs": RetrievalConfig(method="memory", top_k=2, max_tokens=4000, score_threshold=0.5), "tools": [ - ToolRef( - type=ToolType.ACTION, - id=self.action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], }, { "name": "openai", "description": "test for openai", - "memory": {"type": "naive"}, + "memory": {"type": "message_window", "max_messages": 50, "max_tokens": 2000}, "retrievals": [{"type": "collection", "id": self.collection_id}], "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}, - "tools": [{"type": "action", "id": self.action_id}, - {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] - - } + "tools": [{"type": "plugin", "id": "open_weather/get_hourly_forecast"}], + }, ] for update_data in update_data_list: res = await a_update_assistant(assistant_id=self.assistant_id, **update_data) res_dict = vars(res) - logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') + logger.info(f"response_dict:{res_dict}, except_dict:{update_data}") assume_assistant_result(update_data, res_dict) @pytest.mark.run(order=66) @pytest.mark.asyncio async def test_a_delete_assistant(self): - # List assistants. assistants = await a_list_assistants(limit=100) @@ -175,99 +149,89 @@ async def test_a_delete_assistant(self): @pytest.mark.test_async class TestChat(Base): + @pytest.mark.run(order=55) + @pytest.mark.asyncio + async def test_a_create_chat(self): + for x in range(2): + # Create a chat. + name = f"test_chat{x + 1}" + res = await a_create_chat(assistant_id=self.assistant_id, name=name) + res_dict = vars(res) + pytest.assume(res_dict["name"] == name) + Base.chat_id = res.chat_id - @pytest.mark.run(order=55) - @pytest.mark.asyncio - async def test_a_create_chat(self): - - for x in range(2): - - # Create a chat. - name = f"test_chat{x + 1}" - res = await a_create_chat(assistant_id=self.assistant_id, name=name) - res_dict = vars(res) - pytest.assume(res_dict["name"] == name) - Base.chat_id = res.chat_id - - @pytest.mark.run(order=56) - @pytest.mark.asyncio - async def test_a_list_chats(self): + @pytest.mark.run(order=56) + @pytest.mark.asyncio + async def test_a_list_chats(self): + # List chats. - # List chats. 
+ nums_limit = 1 + res = await a_list_chats(limit=nums_limit, assistant_id=self.assistant_id) + pytest.assume(len(res) == nums_limit) - nums_limit = 1 - res = await a_list_chats(limit=nums_limit, assistant_id=self.assistant_id) - pytest.assume(len(res) == nums_limit) + after_id = res[-1].chat_id + after_res = await a_list_chats(limit=nums_limit, after=after_id, assistant_id=self.assistant_id) + pytest.assume(len(after_res) == nums_limit) - after_id = res[-1].chat_id - after_res = await a_list_chats(limit=nums_limit, after=after_id, assistant_id=self.assistant_id) - pytest.assume(len(after_res) == nums_limit) + twice_nums_list = await a_list_chats(limit=nums_limit * 2, assistant_id=self.assistant_id) + pytest.assume(len(twice_nums_list) == nums_limit * 2) + pytest.assume(after_res[-1] == twice_nums_list[-1]) + pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - twice_nums_list = await a_list_chats(limit=nums_limit * 2, assistant_id=self.assistant_id) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) + before_id = after_res[0].chat_id + before_res = await a_list_chats(limit=nums_limit, before=before_id, assistant_id=self.assistant_id) + pytest.assume(len(before_res) == nums_limit) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) + pytest.assume(before_res[0] == twice_nums_list[0]) - before_id = after_res[0].chat_id - before_res = await a_list_chats(limit=nums_limit, before=before_id, assistant_id=self.assistant_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) - pytest.assume(before_res[0] == twice_nums_list[0]) + @pytest.mark.run(order=57) + @pytest.mark.asyncio + async def test_a_get_chat(self): + # Get a chat. + res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) + res_dict = vars(res) + pytest.assume(res_dict["chat_id"] == self.chat_id) + pytest.assume(res_dict["assistant_id"] == self.assistant_id) - @pytest.mark.run(order=57) - @pytest.mark.asyncio - async def test_a_get_chat(self): + @pytest.mark.run(order=58) + @pytest.mark.asyncio + async def test_a_update_chat(self): + # Update a chat. - # Get a chat. - res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) - res_dict = vars(res) - pytest.assume(res_dict["chat_id"] == self.chat_id) - pytest.assume(res_dict["assistant_id"] == self.assistant_id) + metadata = {"test": "test"} + name = "test_update_chat" + res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata, name=name) + res_dict = vars(res) + pytest.assume(res_dict["metadata"] == metadata) + pytest.assume(res_dict["name"] == name) - @pytest.mark.run(order=58) - @pytest.mark.asyncio - async def test_a_update_chat(self): + @pytest.mark.run(order=65) + @pytest.mark.asyncio + async def test_a_delete_chat(self): + # List chats. - # Update a chat. + chats = await a_list_chats(assistant_id=self.assistant_id) + old_nums = len(chats) + for index, chat in enumerate(chats): + chat_id = chat.chat_id - metadata = {"test": "test"} - name = "test_update_chat" - res = await a_update_chat(assistant_id=self.assistant_id, chat_id=self.chat_id, metadata=metadata, name=name) - res_dict = vars(res) - pytest.assume(res_dict["metadata"] == metadata) - pytest.assume(res_dict["name"] == name) + # Delete a chat. 
- @pytest.mark.run(order=65) - @pytest.mark.asyncio - async def test_a_delete_chat(self): + await a_delete_chat(assistant_id=self.assistant_id, chat_id=str(chat_id)) # List chats. - - chats = await a_list_chats(assistant_id=self.assistant_id) - old_nums = len(chats) - for index, chat in enumerate(chats): - chat_id = chat.chat_id - - # Delete a chat. - - await a_delete_chat(assistant_id=self.assistant_id, chat_id=str(chat_id)) - - # List chats. - if index == old_nums - 1: - new_chats = await a_list_chats(assistant_id=self.assistant_id) - new_nums = len(new_chats) - pytest.assume(new_nums == 0) + if index == old_nums - 1: + new_chats = await a_list_chats(assistant_id=self.assistant_id) + new_nums = len(new_chats) + pytest.assume(new_nums == 0) @pytest.mark.test_async class TestMessage(Base): - @pytest.mark.run(order=59) @pytest.mark.asyncio async def test_a_create_message(self): - for x in range(2): - # Create a user message. text = "hello, what is the weather like in HongKong?" @@ -281,7 +245,6 @@ async def test_a_create_message(self): @pytest.mark.run(order=60) @pytest.mark.asyncio async def test_a_list_messages(self): - # List messages. nums_limit = 1 @@ -289,19 +252,22 @@ async def test_a_list_messages(self): pytest.assume(len(res) == nums_limit) after_id = res[-1].message_id - after_res = await a_list_messages(limit=nums_limit, after=after_id, assistant_id=self.assistant_id, - chat_id=self.chat_id) + after_res = await a_list_messages( + limit=nums_limit, after=after_id, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(after_res) == nums_limit) - twice_nums_list = await a_list_messages(limit=nums_limit * 2, assistant_id=self.assistant_id, - chat_id=self.chat_id) + twice_nums_list = await a_list_messages( + limit=nums_limit * 2, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(twice_nums_list) == nums_limit * 2) pytest.assume(after_res[-1] == twice_nums_list[-1]) pytest.assume(after_res[0] == twice_nums_list[nums_limit]) before_id = after_res[0].message_id - before_res = await a_list_messages(limit=nums_limit, before=before_id, assistant_id=self.assistant_id, - chat_id=self.chat_id) + before_res = await a_list_messages( + limit=nums_limit, before=before_id, assistant_id=self.assistant_id, chat_id=self.chat_id + ) pytest.assume(len(before_res) == nums_limit) pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) pytest.assume(before_res[0] == twice_nums_list[0]) @@ -309,7 +275,6 @@ async def test_a_list_messages(self): @pytest.mark.run(order=61) @pytest.mark.asyncio async def test_a_get_message(self): - # Get a message. res = await a_get_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id) @@ -321,23 +286,21 @@ async def test_a_get_message(self): @pytest.mark.run(order=62) @pytest.mark.asyncio async def test_a_update_message(self): - # Update a message. metadata = {"test": "test"} - res = await a_update_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id, - metadata=metadata) + res = await a_update_message( + assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id, metadata=metadata + ) res_dict = vars(res) pytest.assume(res_dict["metadata"] == metadata) @pytest.mark.run(order=63) @pytest.mark.asyncio async def test_a_generate_message(self): - # Generate an assistant message. 
- res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id,
- system_prompt_variables={})
+ res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id, system_prompt_variables={})
 res_dict = vars(res)
 pytest.assume(res_dict["role"] == "assistant")
 pytest.assume(res_dict["content"] is not None)
@@ -347,15 +310,29 @@

 @pytest.mark.run(order=64)
 @pytest.mark.asyncio
- async def test_a_generate_message_by_stream(self):
+ async def test_a_clean_chat_context(self):
+ # Clean the chat context.
+
+ res = await a_clean_chat_context(assistant_id=self.assistant_id, chat_id=self.chat_id)
+ res_dict = vars(res)
+ pytest.assume(res_dict["role"] == "system")
+ pytest.assume(res_dict["content"] is not None)
+ pytest.assume(res_dict["assistant_id"] == self.assistant_id)
+ pytest.assume(res_dict["chat_id"] == self.chat_id)
+ pytest.assume(vars(res_dict["content"])["text"] == "context_cleaned")

+ @pytest.mark.run(order=64)
+ @pytest.mark.asyncio
+ async def test_a_generate_message_by_stream(self):
 assistant_dict = {
 "model_id": Config.openai_chat_completion_model_id,
 "name": "test",
 "description": "test for assistant",
- "memory": AssistantNaiveMemory(),
- "system_prompt_template": ["You know the meaning of various numbers.",
- "No matter what the user's language is, you will use the {{langugae}} to explain."],
+ "memory": AssistantMessageWindowMemory(max_tokens=2000),
+ "system_prompt_template": [
+ "You know the meaning of various numbers.",
+ "No matter what the user's language is, you will use the {{langugae}} to explain.",
+ ],
 "metadata": {"test": "test"},
 "retrievals": [
 RetrievalRef(
@@ -363,23 +340,13 @@
 id=self.collection_id,
 ),
 ],
- "retrieval_configs": RetrievalConfig(
- method="memory",
- top_k=1,
- max_tokens=5000,
- score_threshold=0.04
-
- ),
+ "retrieval_configs": RetrievalConfig(method="memory", top_k=1, max_tokens=5000, score_threshold=0.04),
 "tools": [
- ToolRef(
- type=ToolType.ACTION,
- id=self.action_id,
- ),
 ToolRef(
 type=ToolType.PLUGIN,
 id="open_weather/get_hourly_forecast",
 )
- ]
+ ],
 }
 assistant_res = await a_create_assistant(**assistant_dict)
 assistant_id = assistant_res.assistant_id
@@ -387,7 +354,7 @@

 chat_res = await a_create_chat(assistant_id=assistant_id, name="test_chat")
 chat_id = chat_res.chat_id

- logger.info(f'chat_id:{chat_id}')
+ logger.info(f"chat_id:{chat_id}")

 # create user message

@@ -399,8 +366,9 @@

 # Generate an assistant message by stream.

- stream_res = await a_generate_message(assistant_id=assistant_id, chat_id=chat_id,
- system_prompt_variables={}, stream=True)
+ stream_res = await a_generate_message(
+ assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True
+ )
 except_list = ["MessageChunk", "Message"]
 real_list = []
 async for item in stream_res:
@@ -418,16 +386,17 @@

 @pytest.mark.run(order=70)
 @pytest.mark.asyncio
 async def test_a_assistant_by_user_message_retrieval_and_stream(self):
- # Create an assistant.
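`a_clean_chat_context` is new in this patch. Judging only by the assertions above, it answers with a system message whose text is the literal `"context_cleaned"`, after which earlier turns no longer count toward the chat's context. A minimal usage sketch, assuming the function is exported from `taskingai.assistant` like its sibling calls:

```python
from taskingai.assistant import a_clean_chat_context


async def reset_context(assistant_id: str, chat_id: str) -> None:
    # Clear the conversational context of an existing chat.
    res = await a_clean_chat_context(assistant_id=assistant_id, chat_id=chat_id)
    assert res.role == "system"
    assert res.content.text == "context_cleaned"  # marker asserted by the tests above
```

The retrieval-and-streaming variants of the assistant tests continue below.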
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( @@ -435,24 +404,23 @@ async def test_a_assistant_by_user_message_retrieval_and_stream(self): id=self.collection_id, ), ], - "retrieval_configs": { - "method": "user_message", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - } + "retrieval_configs": {"method": "user_message", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = await a_create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text) - generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True) - final_content = '' + create_message_res = await a_create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = await a_generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" async for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -466,16 +434,17 @@ async def test_a_assistant_by_user_message_retrieval_and_stream(self): @pytest.mark.run(order=70) @pytest.mark.asyncio async def test_a_assistant_by_memory_retrieval_and_stream(self): - # Create an assistant. 
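Every streaming test in this patch consumes the generator the same way. A condensed sketch of that loop; we assume `Message`, `MessageChunk`, and `a_generate_message` are all importable from `taskingai.assistant`, as the sync suite's wildcard import suggests:

```python
from taskingai.assistant import Message, MessageChunk, a_generate_message


async def stream_reply(assistant_id: str, chat_id: str) -> str:
    stream = await a_generate_message(
        assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True
    )
    text = ""
    async for item in stream:
        if isinstance(item, MessageChunk):
            text += item.delta  # incremental delta, logged by the tests above
        elif isinstance(item, Message):
            break  # the consolidated Message object closes the stream
    return text
```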
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( @@ -483,28 +452,23 @@ async def test_a_assistant_by_memory_retrieval_and_stream(self): id=self.collection_id, ), ], - "retrieval_configs": { - "method": "memory", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - - } + "retrieval_configs": {"method": "memory", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = await a_create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, text=text) - generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = await a_create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = await a_generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" async for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -518,16 +482,17 @@ async def test_a_assistant_by_memory_retrieval_and_stream(self): @pytest.mark.run(order=70) @pytest.mark.asyncio async def test_a_assistant_by_function_call_retrieval_and_stream(self): - # Create an assistant. 
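The recurring `AssistantNaiveMemory()` to `AssistantMessageWindowMemory(max_tokens=2000)` hunks are the other theme of this patch: naive and zero memory disappear, and message-window memory is supplied either as an object or as a plain dict. Both spellings, with field names taken from the dict literals in this diff (`max_messages` appears to be optional in the object form):

```python
from taskingai.assistant.memory import AssistantMessageWindowMemory

# Object form, used by most tests in this patch:
memory = AssistantMessageWindowMemory(max_tokens=2000)

# Equivalent dict form, used by the i == 0 branch of test_create_assistant:
memory_dict = {"type": "message_window", "max_messages": 50, "max_tokens": 2000}
```

The function-call retrieval variant follows.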
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( @@ -535,28 +500,23 @@ async def test_a_assistant_by_function_call_retrieval_and_stream(self): id=self.collection_id, ), ], - "retrieval_configs": - { - "method": "function_call", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - } + "retrieval_configs": {"method": "function_call", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = await a_create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, text=text) - generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = await a_create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = await a_generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" async for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -570,16 +530,17 @@ async def test_a_assistant_by_function_call_retrieval_and_stream(self): @pytest.mark.run(order=70) @pytest.mark.asyncio async def test_a_assistant_by_not_support_function_call_retrieval_and_stream(self): - # Create an assistant. assistant_dict = { "model_id": Config.anthropic_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( @@ -591,8 +552,7 @@ async def test_a_assistant_by_not_support_function_call_retrieval_and_stream(sel method="function_call", top_k=1, max_tokens=5000, - - ) + ), } with pytest.raises(Exception) as e: assistant_res = await a_create_assistant(**assistant_dict) @@ -601,42 +561,40 @@ async def test_a_assistant_by_not_support_function_call_retrieval_and_stream(sel @pytest.mark.run(order=70) @pytest.mark.asyncio async def test_a_assistant_by_all_tool_and_stream(self): - # Create an assistant. 
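Three `retrieval_configs` methods are exercised above: `user_message`, `memory`, and `function_call`, and the Anthropic case documents that `function_call` additionally requires a model with function-call support (creation is expected to raise "not support function call to use retrieval"). The object form for reference, imported as in the sync suite; `score_threshold` is omitted in some tests, so it appears to be optional:

```python
from taskingai.client.models import RetrievalConfig

user_message_cfg = RetrievalConfig(method="user_message", top_k=1, max_tokens=5000, score_threshold=0.5)
memory_cfg = RetrievalConfig(method="memory", top_k=1, max_tokens=5000, score_threshold=0.5)

# Needs a model that supports function calling:
function_call_cfg = RetrievalConfig(method="function_call", top_k=1, max_tokens=5000)
```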
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "tools": [ - ToolRef( - type=ToolType.ACTION, - id=self.action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], } assistant_res = await a_create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = await a_create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = await a_create_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, text=text) - generate_message_res = await a_generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = await a_create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = await a_generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" async for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -650,30 +608,25 @@ async def test_a_assistant_by_all_tool_and_stream(self): @pytest.mark.run(order=70) @pytest.mark.asyncio async def test_a_assistant_by_not_support_function_call_tool_and_stream(self): - # Create an assistant. assistant_dict = { "model_id": Config.anthropic_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "tools": [ - ToolRef( - type=ToolType.ACTION, - id=self.action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], } - with pytest.raises(Exception) as e: - assistant_res = await a_create_assistant(**assistant_dict) - assert "not support function call to use the tools" in str(e.value) - + assistant_res = await a_create_assistant(**assistant_dict) + assistant_res_dict = vars(assistant_res) diff --git a/test/testcase/test_async/test_async_retrieval.py b/test/testcase/test_async/test_async_retrieval.py index 277cdfd..123ecc1 100644 --- a/test/testcase/test_async/test_async_retrieval.py +++ b/test/testcase/test_async/test_async_retrieval.py @@ -22,19 +22,33 @@ class TestCollection(Base): async def test_a_create_collection(self): # Create a collection. 
- create_dict = {
- "capacity": 1000,
- "embedding_model_id": Config.openai_text_embedding_model_id,
- "name": "test",
- "description": "description",
- "metadata": {"key1": "value1", "key2": "value2"},
- }
- for x in range(2):
+ create_list = [
+ {
+ "capacity": 1000,
+ "embedding_model_id": Config.openai_text_embedding_model_id,
+ "name": "test",
+ "description": "description",
+ "metadata": {"key1": "value1", "key2": "value2"},
+ },
+ {
+ "capacity": 1000,
+ "embedding_model_id": Config.openai_text_embedding_model_id,
+ "type": "qa",
+ "name": "test",
+ "description": "description",
+ "metadata": {"key1": "value1", "key2": "value2"},
+ },
+ ]
+ for index, create_dict in enumerate(create_list):
 res = await a_create_collection(**create_dict)
 res_dict = vars(res)
 logger.info(res_dict)
 assume_collection_result(create_dict, res_dict)
- Base.collection_id = res_dict["collection_id"]
+ if index == 0:
+ Base.collection_id = res_dict["collection_id"]
+ else:
+ Base.qa_collection_id = res_dict["collection_id"]

 @pytest.mark.run(order=22)
 @pytest.mark.asyncio
@@ -108,7 +122,7 @@ class TestRecord(Base):
 ]
 upload_file_data_list = []
-
+ upload_qa_file_data_list = []
 base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 files = os.listdir(base_path + "/files")
 for file in files:
@@ -118,6 +132,14 @@ class TestRecord(Base):
 upload_file_dict.update({"file": open(filepath, "rb")})
 upload_file_data_list.append(upload_file_dict)

+ qa_files = os.listdir(base_path + "/qa_files")
+ for file in qa_files:
+ filepath = os.path.join(base_path, "qa_files", file)
+ if os.path.isfile(filepath):
+ upload_qa_file_dict = {"purpose": "qa_record_file"}
+ upload_qa_file_dict.update({"file": open(filepath, "rb")})
+ upload_qa_file_data_list.append(upload_qa_file_dict)
+
 @pytest.mark.run(order=31)
 @pytest.mark.asyncio
 @pytest.mark.parametrize("text_splitter", text_splitter_list)
@@ -176,6 +198,29 @@ async def test_a_create_record_by_file(self, upload_file_data):
 res_dict = vars(res)
 assume_record_result(create_record_data, res_dict)

+ @pytest.mark.run(order=32)
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("upload_qa_file_data", upload_qa_file_data_list)
+ async def test_create_record_by_qa_file(self, upload_qa_file_data):
+ # upload file
+ upload_file_res = await a_upload_file(**upload_qa_file_data)
+ upload_file_dict = vars(upload_file_res)
+ file_id = upload_file_dict["file_id"]
+ pytest.assume(file_id is not None)
+
+ text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
+ create_record_data = {
+ "type": "qa_sheet",
+ "collection_id": self.qa_collection_id,
+ "file_id": file_id,
+ "metadata": {"key1": "value1", "key2": "value2"},
+ }
+
+ res = await a_create_record(**create_record_data)
+ res_dict = vars(res)
+ assume_record_result(create_record_data, res_dict)
+ Base.qa_record_id = res_dict["record_id"]
+
 @pytest.mark.run(order=32)
 @pytest.mark.asyncio
 async def test_a_list_records(self):
@@ -200,11 +245,35 @@ async def test_a_list_records(self):
 pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
 pytest.assume(before_res[0] == twice_nums_list[0])

+ @pytest.mark.run(order=32)
+ @pytest.mark.asyncio
+ async def test_a_list_qa_records(self):
+ # List records.
+ + nums_limit = 1 + res = await a_list_records(limit=nums_limit, collection_id=self.qa_collection_id) + pytest.assume(len(res) == nums_limit) + + after_id = res[-1].record_id + after_res = await a_list_records(limit=nums_limit, after=after_id, collection_id=self.qa_collection_id) + pytest.assume(len(after_res) == nums_limit) + + twice_nums_list = await a_list_records(limit=nums_limit * 2, collection_id=self.qa_collection_id) + pytest.assume(len(twice_nums_list) == nums_limit * 2) + pytest.assume(after_res[-1] == twice_nums_list[-1]) + pytest.assume(after_res[0] == twice_nums_list[nums_limit]) + + before_id = after_res[0].record_id + before_res = await a_list_records(limit=nums_limit, before=before_id, collection_id=self.qa_collection_id) + pytest.assume(len(before_res) == nums_limit) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) + pytest.assume(before_res[0] == twice_nums_list[0]) + @pytest.mark.run(order=33) @pytest.mark.asyncio async def test_a_get_record(self): # Get a record. - + await asyncio.sleep(Config.sleep_time) res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id) logger.info(f"a_get_record:{res}") res_dict = vars(res) @@ -212,12 +281,24 @@ async def test_a_get_record(self): pytest.assume(res_dict["record_id"] == self.record_id) pytest.assume(res_dict["status"] == "ready") + @pytest.mark.run(order=33) + @pytest.mark.asyncio + async def test_a_get_qa_record(self): + # Get a record. + await asyncio.sleep(Config.sleep_time) + res = await a_get_record(collection_id=self.qa_collection_id, record_id=self.qa_record_id) + logger.info(f"a_get_record:{res}") + res_dict = vars(res) + pytest.assume(res_dict["collection_id"] == self.qa_collection_id) + pytest.assume(res_dict["record_id"] == self.qa_record_id) + pytest.assume(res_dict["status"] == "ready") + @pytest.mark.run(order=34) @pytest.mark.asyncio @pytest.mark.parametrize("text_splitter", text_splitter_list) async def test_a_update_record_by_text(self, text_splitter): # Update a record. - + await asyncio.sleep(Config.sleep_time) update_record_data = { "collection_id": self.collection_id, "record_id": self.record_id, @@ -235,7 +316,7 @@ async def test_a_update_record_by_text(self, text_splitter): @pytest.mark.parametrize("text_splitter", text_splitter_list) async def test_a_update_record_by_web(self, text_splitter): # Update a record. - + await asyncio.sleep(Config.sleep_time) update_record_data = { "type": "web", "title": "Machine learning", @@ -255,6 +336,7 @@ async def test_a_update_record_by_web(self, text_splitter): @pytest.mark.parametrize("upload_file_data", upload_file_data_list[2:3]) async def test_a_update_record_by_file(self, upload_file_data): # upload file + await asyncio.sleep(Config.sleep_time) upload_file_res = await a_upload_file(**upload_file_data) upload_file_dict = vars(upload_file_res) file_id = upload_file_dict["file_id"] @@ -276,6 +358,31 @@ async def test_a_update_record_by_file(self, upload_file_data): res_dict = vars(res) assume_record_result(update_record_data, res_dict) + @pytest.mark.run(order=34) + @pytest.mark.asyncio + @pytest.mark.parametrize("upload_qa_file_data", upload_qa_file_data_list) + async def test_a_update_qa_record(self, upload_qa_file_data): + # upload file + await asyncio.sleep(Config.sleep_time) + upload_file_res = await a_upload_file(**upload_qa_file_data) + upload_file_dict = vars(upload_file_res) + file_id = upload_file_dict["file_id"] + pytest.assume(file_id is not None) + + # Update a record. 
+ + update_record_data = { + "type": "qa_sheet", + "collection_id": self.qa_collection_id, + "record_id": self.qa_record_id, + "file_id": file_id, + "metadata": {"test": "test"}, + } + res = await a_update_record(**update_record_data) + logger.info(f"a_update_record:{res}") + res_dict = vars(res) + # assume_record_result(update_record_data, res_dict) + @pytest.mark.run(order=79) @pytest.mark.asyncio async def test_a_delete_record(self): @@ -302,6 +409,32 @@ async def test_a_delete_record(self): new_nums = len(new_records) pytest.assume(new_nums == 0) + @pytest.mark.run(order=79) + @pytest.mark.asyncio + async def test_a_delete_qa_record(self): + # List records. + + records = await a_list_records( + collection_id=self.qa_collection_id, order="desc", limit=20, after=None, before=None + ) + old_nums = len(records) + for index, record in enumerate(records): + record_id = record.record_id + + # Delete a record. + + await a_delete_record(collection_id=self.qa_collection_id, record_id=record_id) + + # List records. + if index == old_nums - 1: + new_records = await a_list_records( + collection_id=self.qa_collection_id, order="desc", limit=20, after=None, before=None + ) + record_ids = [record.record_id for record in new_records] + pytest.assume(record_id not in record_ids) + new_nums = len(new_records) + pytest.assume(new_nums == 0) + @pytest.mark.test_async class TestChunk(Base): @@ -321,91 +454,3 @@ async def test_a_query_chunks(self): chunk_dict = vars(chunk) assume_query_chunk_result(query_text, chunk_dict) pytest.assume(chunk_dict["score"] >= 0.04) - - @pytest.mark.run(order=42) - @pytest.mark.asyncio - async def test_create_chunk(self): - # Create a chunk. - create_chunk_data = { - "collection_id": self.collection_id, - "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", - } - res = await a_create_chunk(**create_chunk_data) - res_dict = vars(res) - assume_chunk_result(create_chunk_data, res_dict) - Base.chunk_id = res_dict["chunk_id"] - - @pytest.mark.run(order=43) - @pytest.mark.asyncio - async def test_list_chunks(self): - # List chunks. 
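The QA additions to the retrieval suite form one flow: a collection created with `type="qa"`, a file uploaded with `purpose="qa_record_file"`, and a record of type `qa_sheet` pointing at that file. A sketch of the happy path under stated assumptions: the `a_`-prefixed async names mirror the calls used above, and `a_create_record` is assumed to exist alongside `a_update_record` and `a_delete_record`:

```python
from taskingai.file import a_upload_file
from taskingai.retrieval import a_create_collection, a_create_record


async def create_qa_collection_with_record(embedding_model_id: str, qa_file_path: str):
    collection = await a_create_collection(
        capacity=1000,
        embedding_model_id=embedding_model_id,
        type="qa",  # QA collections take qa_sheet records instead of text/web/file records
        name="qa-demo",
        description="QA demo collection",
    )
    with open(qa_file_path, "rb") as f:
        uploaded = await a_upload_file(file=f, purpose="qa_record_file")
    record = await a_create_record(
        type="qa_sheet",
        collection_id=collection.collection_id,
        file_id=uploaded.file_id,
    )
    return collection, record
```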
- - nums_limit = 1 - res = await a_list_chunks(limit=nums_limit, collection_id=self.collection_id) - pytest.assume(len(res) == nums_limit) - - after_id = res[-1].chunk_id - after_res = await a_list_chunks(limit=nums_limit, after=after_id, collection_id=self.collection_id) - pytest.assume(len(after_res) == nums_limit) - - twice_nums_list = await a_list_chunks(limit=nums_limit * 2, collection_id=self.collection_id) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - - before_id = after_res[0].chunk_id - before_res = await a_list_chunks(limit=nums_limit, before=before_id, collection_id=self.collection_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) - pytest.assume(before_res[0] == twice_nums_list[0]) - - @pytest.mark.run(order=44) - @pytest.mark.asyncio - async def test_get_chunk(self): - # list chunks - - chunks = list_chunks(collection_id=self.collection_id) - for chunk in chunks: - chunk_id = chunk.chunk_id - res = get_chunk(collection_id=self.collection_id, chunk_id=chunk_id) - logger.info(f"get chunk response: {res}") - res_dict = vars(res) - pytest.assume(res_dict["collection_id"] == self.collection_id) - pytest.assume(res_dict["chunk_id"] == chunk_id) - - @pytest.mark.run(order=45) - @pytest.mark.asyncio - async def test_update_chunk(self): - # Update a chunk. - - update_chunk_data = { - "collection_id": self.collection_id, - "chunk_id": self.chunk_id, - "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", - "metadata": {"test": "test"}, - } - res = await a_update_chunk(**update_chunk_data) - res_dict = vars(res) - assume_chunk_result(update_chunk_data, res_dict) - - @pytest.mark.run(order=46) - @pytest.mark.asyncio - async def test_delete_chunk(self): - # List chunks. - - await asyncio.sleep(Config.sleep_time) - - chunks = await a_list_chunks(collection_id=self.collection_id, limit=5) - old_nums = len(chunks) - for index, chunk in enumerate(chunks): - chunk_id = chunk.chunk_id - - # Delete a chunk. - - delete_chunk(collection_id=self.collection_id, chunk_id=chunk_id) - - # List chunks. 
- - new_chunks = await a_list_chunks(collection_id=self.collection_id) - chunk_ids = [chunk.chunk_id for chunk in new_chunks] - pytest.assume(chunk_id not in chunk_ids) diff --git a/test/testcase/test_async/test_async_tool.py b/test/testcase/test_async/test_async_tool.py deleted file mode 100644 index 2442e1f..0000000 --- a/test/testcase/test_async/test_async_tool.py +++ /dev/null @@ -1,227 +0,0 @@ -import pytest -from test.config import Config -from taskingai.tool import a_bulk_create_actions, a_get_action, a_update_action, a_delete_action, a_run_action, a_list_actions, ActionAuthentication, ActionAuthenticationType -from test.common.logger import logger -from test.testcase.test_async import Base - - -@pytest.mark.test_async -class TestAction(Base): - - authentication_list = [ - { - "type": "bearer", - "secret": "ASD213df" - }, - ActionAuthentication(type=ActionAuthenticationType.BEARER, secret="ASD213df") - ] - - - @pytest.mark.run(order=11) - @pytest.mark.asyncio - @pytest.mark.parametrize("authentication", authentication_list) - async def test_a_bulk_create_actions(self, authentication): - - schema = { - "openapi_schema": { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location 123", - "operationId": "GetCurrentWeather123", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - } - } - } - - } - - } - schema.update({"authentication": authentication}) - - # Create an action. - for i in range(2): - res = await a_bulk_create_actions(**schema) - for action in res: - action_dict = vars(action) - logger.info(action_dict) - for key in schema.keys(): - if key != "authentication": - for k, v in schema[key].items(): - pytest.assume(action_dict[key][k] == v) - else: - if isinstance(schema[key], ActionAuthentication): - schema[key] = vars(schema[key]) - for k, v in schema[key].items(): - if v is None: - pytest.assume(vars(action_dict[key])[k] == v) - elif k == "type": - pytest.assume(vars(action_dict[key])[k] == v) - else: - pytest.assume("*" in vars(action_dict[key])[k]) - Base.action_id = res[0].action_id - - @pytest.mark.run(order=12) - @pytest.mark.asyncio - async def test_a_list_actions(self): - - # List actions. - - nums_limit = 1 - res = await a_list_actions(limit=nums_limit) - pytest.assume(len(res) == nums_limit) - - after_id = res[-1].action_id - after_res = await a_list_actions(limit=nums_limit, after=after_id) - pytest.assume(len(after_res) == nums_limit) - - twice_nums_list = await a_list_actions(limit=nums_limit * 2) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - - before_id = after_res[0].action_id - before_res = await a_list_actions(limit=nums_limit, before=before_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit-1]) - pytest.assume(before_res[0] == twice_nums_list[0]) - - @pytest.mark.run(order=13) - @pytest.mark.asyncio - async def test_a_get_action(self): - - # Get an action. 
- - res = await a_get_action(action_id=self.action_id) - res_dict = vars(res) - logger.info(res_dict["openapi_schema"].keys()) - pytest.assume(res_dict["action_id"] == self.action_id) - - @pytest.mark.run(order=14) - @pytest.mark.asyncio - @pytest.mark.parametrize("authentication", authentication_list) - async def test_a_update_action(self, authentication): - - # Update an action. - - update_schema = { - "openapi_schema": { - "openapi": "3.0.0", - "info": { - "title": "Numbers API", - "version": "1.0.0", - "description": "API for fetching interesting number facts" - }, - "servers": [ - { - "url": "http://numbersapi.com" - } - ], - "paths": { - "/{number}": { - "get": { - "description": "Get fact about a number", - "operationId": "get_number_fact", - "parameters": [ - { - "name": "number", - "in": "path", - "required": True, - "description": "The number to get the fact for", - "schema": { - "type": "integer" - } - } - ], - "responses": { - "200": { - "description": "A fact about the number", - "content": { - "text/plain": { - "schema": { - "type": "string" - } - } - } - } - } - } - } - } - } - } - update_schema.update({"authentication": authentication}) - - res = await a_update_action(action_id=self.action_id, **update_schema) - res_dict = vars(res) - logger.info(res_dict) - for key in update_schema.keys(): - if key != "authentication": - for k, v in update_schema[key].items(): - pytest.assume(res_dict[key][k] == v) - else: - if isinstance(update_schema[key], ActionAuthentication): - update_schema[key] = vars(update_schema[key]) - for k, v in update_schema[key].items(): - if v is None: - pytest.assume(vars(res_dict[key])[k] == v) - elif k == "type": - pytest.assume(vars(res_dict[key])[k] == v) - else: - pytest.assume("*" in vars(res_dict[key])[k]) - - @pytest.mark.run(order=15) - @pytest.mark.asyncio - async def test_a_run_action(self): - - # Run an action. - - parameters = { - "number": 42 - } - res = await a_run_action(action_id=self.action_id, parameters=parameters) - logger.info(f'async run action{res}') - pytest.assume(res['status'] == 200) - pytest.assume(res["data"]) - - @pytest.mark.run(order=80) - @pytest.mark.asyncio - async def test_a_delete_action(self): - - # List actions. - - actions = await a_list_actions(limit=100) - old_nums = len(actions) - - for index, action in enumerate(actions): - action_id = action.action_id - - # Delete an action. 
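With `test_async_tool.py` deleted and every `ToolType.ACTION` reference dropped, plugins are the only tool type left in these suites. The two equivalent spellings that survive in the diff:

```python
from taskingai.client.models import ToolRef, ToolType

# Object form:
tool = ToolRef(type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast")

# Dict form, as used in the i == 0 branch of test_create_assistant:
tool_dict = {"type": "plugin", "id": "open_weather/get_hourly_forecast"}
```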
- - await a_delete_action(action_id=action_id) - if index == old_nums - 1: - new_actions = await a_list_actions() - new_nums = len(new_actions) - pytest.assume(new_nums == 0) diff --git a/test/testcase/test_sync/__init__.py b/test/testcase/test_sync/__init__.py index e69de29..b5ac17b 100644 --- a/test/testcase/test_sync/__init__.py +++ b/test/testcase/test_sync/__init__.py @@ -0,0 +1,4 @@ +class Base: + + collection_id, record_id, chunk_id, action_id, assistant_id, chat_id, message_id = None, None, None, None, None, None, None + qa_collection_id, qa_record_id = None, None \ No newline at end of file diff --git a/test/testcase/test_sync/conftest.py b/test/testcase/test_sync/conftest.py index 966cc8d..d538782 100644 --- a/test/testcase/test_sync/conftest.py +++ b/test/testcase/test_sync/conftest.py @@ -1,10 +1,10 @@ import pytest - -from taskingai.retrieval import list_collections, list_records, list_chunks +# +# from taskingai.retrieval import list_collections, list_records, list_chunks from taskingai.assistant import list_assistants, list_chats, list_messages -from taskingai.tool import list_actions - - +# from taskingai.tool import list_actions +# +# @pytest.fixture(scope="session") def assistant_id(): res = list_assistants() @@ -24,34 +24,3 @@ def message_id(assistant_id, chat_id): res = list_messages(str(assistant_id), str(chat_id)) message_id = res[0].message_id return message_id - - -@pytest.fixture(scope="session") -def collection_id(): - res = list_collections() - collection_id = res[0].collection_id - return collection_id - - -@pytest.fixture(scope="session") -def record_id(collection_id): - res = list_records(str(collection_id)) - record_id = res[-1].record_id - return record_id - - -@pytest.fixture(scope="session") -def chunk_id(collection_id): - res = list_chunks(str(collection_id)) - chunk_id = res[0].chunk_id - return chunk_id - - -@pytest.fixture(scope="session") -def action_id(): - res = list_actions() - action_id = res[-1].action_id - return action_id - - - diff --git a/test/testcase/test_sync/test_sync_assistant.py b/test/testcase/test_sync/test_sync_assistant.py index 1dcc275..9dcb59d 100644 --- a/test/testcase/test_sync/test_sync_assistant.py +++ b/test/testcase/test_sync/test_sync_assistant.py @@ -2,66 +2,59 @@ from taskingai.assistant import * from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig -from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory +from taskingai.assistant.memory import AssistantMessageWindowMemory from test.config import Config from test.common.logger import logger -from test.common.utils import assume_assistant_result, assume_chat_result, assume_message_result +from test.common.utils import assume_assistant_result +from test.testcase.test_sync import Base @pytest.mark.test_sync -class TestAssistant: - +class TestAssistant(Base): @pytest.mark.run(order=51) - def test_create_assistant(self, collection_id, action_id): - + def test_create_assistant(self): # Create an assistant. 
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.","No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ - RetrievalRef( + RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], - "retrieval_configs": RetrievalConfig( - method="memory", - top_k=1, - max_tokens=5000, - score_threshold=0.5 - - ), + "retrieval_configs": RetrievalConfig(method="memory", top_k=1, max_tokens=5000, score_threshold=0.5), "tools": [ - ToolRef( - type=ToolType.ACTION, - id=action_id, - ), - ToolRef( + ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], } for i in range(4): if i == 0: - assistant_dict.update({"memory": {"type": "naive"}}) - assistant_dict.update({"retrievals": [{"type": "collection", "id": collection_id}]}) - assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}}) - assistant_dict.update({"tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) + assistant_dict.update({"memory": {"type": "message_window", "max_messages": 50, "max_tokens": 2000}}) + assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]}) + assistant_dict.update( + {"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}} + ) + assistant_dict.update({"tools": [{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) res = create_assistant(**assistant_dict) res_dict = vars(res) - logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, res_dict) @pytest.mark.run(order=52) def test_list_assistants(self): - # List assistants. nums_limit = 1 @@ -85,7 +78,6 @@ def test_list_assistants(self): @pytest.mark.run(order=53) def test_get_assistant(self, assistant_id): - # Get an assistant. res = get_assistant(assistant_id=assistant_id) @@ -93,60 +85,46 @@ def test_get_assistant(self, assistant_id): pytest.assume(res_dict["assistant_id"] == assistant_id) @pytest.mark.run(order=54) - def test_update_assistant(self, collection_id, action_id, assistant_id): - + def test_update_assistant(self, assistant_id): # Update an assistant. 
update_data_list = [ { "name": "openai", "description": "test for openai", - "memory": AssistantZeroMemory(), + "memory": AssistantMessageWindowMemory(max_tokens=2000), "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], - "retrieval_configs": RetrievalConfig( - method="memory", - top_k=2, - max_tokens=4000, - score_threshold=0.5 - - ), + "retrieval_configs": RetrievalConfig(method="memory", top_k=2, max_tokens=4000, score_threshold=0.5), "tools": [ - ToolRef( - type=ToolType.ACTION, - id=action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], }, { "name": "openai", "description": "test for openai", - "memory": {"type": "naive"}, - "retrievals": [{"type": "collection", "id": collection_id}], + "memory": {"type": "message_window", "max_messages": 50, "max_tokens": 2000}, + "retrievals": [{"type": "collection", "id": self.collection_id}], "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}, - "tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] - - } + "tools": [{"type": "plugin", "id": "open_weather/get_hourly_forecast"}], + }, ] for update_data in update_data_list: - res = update_assistant(assistant_id=assistant_id, **update_data) res_dict = vars(res) - logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') + logger.info(f"response_dict:{res_dict}, except_dict:{update_data}") assume_assistant_result(update_data, res_dict) @pytest.mark.run(order=66) def test_delete_assistant(self): - # List assistants. assistants = list_assistants(limit=100) @@ -159,7 +137,7 @@ def test_delete_assistant(self): delete_assistant(assistant_id=str(assistant_id)) # List assistants. - if i == old_nums-1: + if i == old_nums - 1: new_assistants = list_assistants() new_nums = len(new_assistants) pytest.assume(new_nums == 0) @@ -167,93 +145,83 @@ def test_delete_assistant(self): @pytest.mark.test_sync class TestChat: + @pytest.mark.run(order=55) + def test_create_chat(self, assistant_id): + for x in range(2): + # Create a chat. + name = f"test_chat{x+1}" + res = create_chat(assistant_id=assistant_id, name=name) + res_dict = vars(res) + pytest.assume(res_dict["name"] == name) - @pytest.mark.run(order=55) - def test_create_chat(self, assistant_id): - - for x in range(2): - - # Create a chat. - name = f"test_chat{x+1}" - res = create_chat(assistant_id=assistant_id, name=name) - res_dict = vars(res) - pytest.assume(res_dict["name"] == name) - - @pytest.mark.run(order=56) - def test_list_chats(self, assistant_id): + @pytest.mark.run(order=56) + def test_list_chats(self, assistant_id): + # List chats. - # List chats. 
+ nums_limit = 1 + res = list_chats(limit=nums_limit, assistant_id=assistant_id) + pytest.assume(len(res) == nums_limit) - nums_limit = 1 - res = list_chats(limit=nums_limit, assistant_id=assistant_id) - pytest.assume(len(res) == nums_limit) + after_id = res[-1].chat_id + after_res = list_chats(limit=nums_limit, after=after_id, assistant_id=assistant_id) + pytest.assume(len(after_res) == nums_limit) - after_id = res[-1].chat_id - after_res = list_chats(limit=nums_limit, after=after_id, assistant_id=assistant_id) - pytest.assume(len(after_res) == nums_limit) + twice_nums_list = list_chats(limit=nums_limit * 2, assistant_id=assistant_id) + pytest.assume(len(twice_nums_list) == nums_limit * 2) + pytest.assume(after_res[-1] == twice_nums_list[-1]) + pytest.assume(after_res[0] == twice_nums_list[nums_limit]) - twice_nums_list = list_chats(limit=nums_limit * 2, assistant_id=assistant_id) - pytest.assume(len(twice_nums_list) == nums_limit * 2) - pytest.assume(after_res[-1] == twice_nums_list[-1]) - pytest.assume(after_res[0] == twice_nums_list[nums_limit]) + before_id = after_res[0].chat_id + before_res = list_chats(limit=nums_limit, before=before_id, assistant_id=assistant_id) + pytest.assume(len(before_res) == nums_limit) + pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) + pytest.assume(before_res[0] == twice_nums_list[0]) - before_id = after_res[0].chat_id - before_res = list_chats(limit=nums_limit, before=before_id, assistant_id=assistant_id) - pytest.assume(len(before_res) == nums_limit) - pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1]) - pytest.assume(before_res[0] == twice_nums_list[0]) + @pytest.mark.run(order=57) + def test_get_chat(self, assistant_id, chat_id): + # Get a chat. - @pytest.mark.run(order=57) - def test_get_chat(self, assistant_id, chat_id): + res = get_chat(assistant_id=assistant_id, chat_id=chat_id) + res_dict = vars(res) + pytest.assume(res_dict["chat_id"] == chat_id) + pytest.assume(res_dict["assistant_id"] == assistant_id) - # Get a chat. + @pytest.mark.run(order=58) + def test_update_chat(self, assistant_id, chat_id): + # Update a chat. - res = get_chat(assistant_id=assistant_id, chat_id=chat_id) - res_dict = vars(res) - pytest.assume(res_dict["chat_id"] == chat_id) - pytest.assume(res_dict["assistant_id"] == assistant_id) + metadata = {"test": "test"} + name = "test_update_chat" + res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata, name=name) + res_dict = vars(res) + pytest.assume(res_dict["metadata"] == metadata) + pytest.assume(res_dict["name"] == name) - @pytest.mark.run(order=58) - def test_update_chat(self, assistant_id, chat_id): + @pytest.mark.run(order=65) + def test_delete_chat(self, assistant_id): + # List chats. - # Update a chat. + chats = list_chats(assistant_id=assistant_id) + old_nums = len(chats) + for index, chat in enumerate(chats): + chat_id = chat.chat_id - metadata = {"test": "test"} - name = "test_update_chat" - res = update_chat(assistant_id=assistant_id, chat_id=chat_id, metadata=metadata, name=name) - res_dict = vars(res) - pytest.assume(res_dict["metadata"] == metadata) - pytest.assume(res_dict["name"] == name) + # Delete a chat. - @pytest.mark.run(order=65) - def test_delete_chat(self, assistant_id): + delete_chat(assistant_id=assistant_id, chat_id=str(chat_id)) # List chats. - - chats = list_chats(assistant_id=assistant_id) - old_nums = len(chats) - for index, chat in enumerate(chats): - chat_id = chat.chat_id - - # Delete a chat. 
-
- delete_chat(assistant_id=assistant_id, chat_id=str(chat_id))
-
- # List chats.
- if index == old_nums-1:
- new_chats = list_chats(assistant_id=assistant_id)
- new_nums = len(new_chats)
- pytest.assume(new_nums == 0)
+ if index == old_nums - 1:
+ new_chats = list_chats(assistant_id=assistant_id)
+ new_nums = len(new_chats)
+ pytest.assume(new_nums == 0)

 @pytest.mark.test_sync
-class TestMessage:
-
+class TestMessage(Base):
 @pytest.mark.run(order=59)
 def test_create_message(self, assistant_id, chat_id):
-
 for x in range(2):
-
 # Create a user message.

 text = "hello, what is the weather like in HongKong?"
@@ -265,31 +233,26 @@ def test_create_message(self, assistant_id, chat_id):

 @pytest.mark.run(order=60)
 def test_list_messages(self, assistant_id, chat_id):
-
 # List messages.

 nums_limit = 1
 res = list_messages(limit=nums_limit, assistant_id=assistant_id, chat_id=chat_id)
 pytest.assume(len(res) == nums_limit)

 after_id = res[-1].message_id
- after_res = list_messages(limit=nums_limit, after=after_id, assistant_id=assistant_id,
- chat_id=chat_id)
+ after_res = list_messages(limit=nums_limit, after=after_id, assistant_id=assistant_id, chat_id=chat_id)
 pytest.assume(len(after_res) == nums_limit)

- twice_nums_list = list_messages(limit=nums_limit * 2, assistant_id=assistant_id,
- chat_id=chat_id)
+ twice_nums_list = list_messages(limit=nums_limit * 2, assistant_id=assistant_id, chat_id=chat_id)
 pytest.assume(len(twice_nums_list) == nums_limit * 2)
 pytest.assume(after_res[-1] == twice_nums_list[-1])
 pytest.assume(after_res[0] == twice_nums_list[nums_limit])

 before_id = after_res[0].message_id
- before_res = list_messages(limit=nums_limit, before=before_id, assistant_id=assistant_id,
- chat_id=chat_id)
+ before_res = list_messages(limit=nums_limit, before=before_id, assistant_id=assistant_id, chat_id=chat_id)
 pytest.assume(len(before_res) == nums_limit)
 pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
 pytest.assume(before_res[0] == twice_nums_list[0])

 @pytest.mark.run(order=61)
 def test_get_message(self, assistant_id, chat_id, message_id):
-
 # Get a message.

 res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=message_id)
@@ -300,7 +263,6 @@ def test_get_message(self, assistant_id, chat_id, message_id):

 @pytest.mark.run(order=62)
 def test_update_message(self, assistant_id, chat_id, message_id):
-
 # Update a message.

 metadata = {"test": "test"}
@@ -310,7 +272,6 @@ def test_update_message(self, assistant_id, chat_id, message_id):

 @pytest.mark.run(order=63)
 def test_generate_message(self, assistant_id, chat_id):
-
 # Generate an assistant message by no stream.

 res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={})
@@ -322,39 +283,47 @@ def test_generate_message(self, assistant_id, chat_id):
 pytest.assume(vars(res_dict["content"])["text"] is not None)

 @pytest.mark.run(order=64)
- def test_generate_message_by_stream(self, collection_id, action_id):
+ def test_clean_chat_context(self, assistant_id, chat_id):
+ # Clean the chat context.
+ res = clean_chat_context(assistant_id=assistant_id, chat_id=chat_id)
+ res_dict = vars(res)
+ pytest.assume(res_dict["role"] == "system")
+ pytest.assume(res_dict["content"] is not None)
+ pytest.assume(res_dict["assistant_id"] == assistant_id)
+ pytest.assume(res_dict["chat_id"] == chat_id)
+ pytest.assume(vars(res_dict["content"])["text"] == "context_cleaned")
+
+ @pytest.mark.run(order=64)
+ def test_generate_message_by_stream(self):
 # Create an assistant.
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], "retrieval_configs": RetrievalConfig( method="memory", top_k=1, max_tokens=5000, - ), "tools": [ - ToolRef( - type=ToolType.ACTION, - id=action_id, - ), ToolRef( type=ToolType.PLUGIN, id="open_weather/get_hourly_forecast", ) - ] + ], } assistant_res = create_assistant(**assistant_dict) assistant_id = assistant_res.assistant_id @@ -367,13 +336,14 @@ def test_generate_message_by_stream(self, collection_id, action_id): # create user message user_message: Message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 10 and separate numbers by comma.") + assistant_id=assistant_id, chat_id=chat_id, text="count from 1 to 10 and separate numbers by comma." + ) # Generate an assistant message by stream. - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True) + stream_res = generate_message( + assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}, stream=True + ) except_list = ["MessageChunk", "Message"] real_list = [] for item in stream_res: @@ -389,45 +359,42 @@ def test_generate_message_by_stream(self, collection_id, action_id): pytest.assume(set(except_list) == set(real_list)) @pytest.mark.run(order=70) - def test_assistant_by_user_message_retrieval_and_stream(self, collection_id): - + def test_assistant_by_user_message_retrieval_and_stream(self): # Create an assistant. assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], - "retrieval_configs": { - "method": "user_message", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - } + "retrieval_configs": {"method": "user_message", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" 
- create_message_res = create_message(assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, - text=text) - generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -439,46 +406,42 @@ def test_assistant_by_user_message_retrieval_and_stream(self, collection_id): assert final_content is not None @pytest.mark.run(order=70) - def test_assistant_by_memory_retrieval_and_stream(self, collection_id): - + def test_assistant_by_memory_retrieval_and_stream(self): # Create an assistant. assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], - "retrieval_configs": { - "method": "memory", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - - } + "retrieval_configs": {"method": "memory", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = create_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, text=text) - generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -490,46 +453,42 @@ def test_assistant_by_memory_retrieval_and_stream(self, collection_id): assert final_content is not None @pytest.mark.run(order=70) - def test_assistant_by_function_call_retrieval_and_stream(self, collection_id): - + def test_assistant_by_function_call_retrieval_and_stream(self): # Create an assistant. 
assistant_dict = { "model_id": Config.openai_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], - "retrieval_configs": - { - "method": "function_call", - "top_k": 1, - "max_tokens": 5000, - "score_threshold": 0.5 - } + "retrieval_configs": {"method": "function_call", "top_k": 1, "max_tokens": 5000, "score_threshold": 0.5}, } assistant_res = create_assistant(**assistant_dict) assistant_res_dict = vars(assistant_res) - logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}') + logger.info(f"response_dict:{assistant_res_dict}, except_dict:{assistant_dict}") assume_assistant_result(assistant_dict, assistant_res_dict) chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat") text = "hello, what is the weather like in HongKong?" - create_message_res = create_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, text=text) - generate_message_res = generate_message(assistant_id=assistant_res.assistant_id, - chat_id=chat_res.chat_id, system_prompt_variables={}, - stream=True) - final_content = '' + create_message_res = create_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text + ) + generate_message_res = generate_message( + assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True + ) + final_content = "" for item in generate_message_res: if isinstance(item, MessageChunk): logger.info(f"MessageChunk: {item.delta}") @@ -541,73 +500,71 @@ def test_assistant_by_function_call_retrieval_and_stream(self, collection_id): assert final_content is not None @pytest.mark.run(order=70) - def test_assistant_by_not_support_function_call_retrieval_and_stream(self, collection_id): - + def test_assistant_by_not_support_function_call_retrieval_and_stream(self): # Create an assistant. assistant_dict = { "model_id": Config.anthropic_chat_completion_model_id, "name": "test", "description": "test for assistant", - "memory": AssistantNaiveMemory(), - "system_prompt_template": ["You know the meaning of various numbers.", - "No matter what the user's language is, you will use the {{langugae}} to explain."], + "memory": AssistantMessageWindowMemory(max_tokens=2000), + "system_prompt_template": [ + "You know the meaning of various numbers.", + "No matter what the user's language is, you will use the {{langugae}} to explain.", + ], "metadata": {"test": "test"}, "retrievals": [ RetrievalRef( type=RetrievalType.COLLECTION, - id=collection_id, + id=self.collection_id, ), ], "retrieval_configs": RetrievalConfig( method="function_call", top_k=1, max_tokens=5000, - - ) + ), } with pytest.raises(Exception) as e: assistant_res = create_assistant(**assistant_dict) assert "not support function call to use retrieval" in str(e.value) @pytest.mark.run(order=70) - def test_assistant_by_all_tool_and_stream(self, action_id): - + def test_assistant_by_all_tool_and_stream(self): # Create an assistant. 
         assistant_dict = {
             "model_id": Config.openai_chat_completion_model_id,
             "name": "test",
             "description": "test for assistant",
-            "memory": AssistantNaiveMemory(),
-            "system_prompt_template": ["You know the meaning of various numbers.",
-                                       "No matter what the user's language is, you will use the {{langugae}} to explain."],
+            "memory": AssistantMessageWindowMemory(max_tokens=2000),
+            "system_prompt_template": [
+                "You know the meaning of various numbers.",
+                "No matter what the user's language is, you will use the {{language}} to explain.",
+            ],
             "metadata": {"test": "test"},
             "tools": [
-                ToolRef(
-                    type=ToolType.ACTION,
-                    id=action_id,
-                ),
                 ToolRef(
                     type=ToolType.PLUGIN,
                     id="open_weather/get_hourly_forecast",
                 )
-            ]
+            ],
         }
         assistant_res = create_assistant(**assistant_dict)
         assistant_res_dict = vars(assistant_res)
-        logger.info(f'response_dict:{assistant_res_dict}, except_dict:{assistant_dict}')
+        logger.info(f"response_dict:{assistant_res_dict}, expect_dict:{assistant_dict}")
         assume_assistant_result(assistant_dict, assistant_res_dict)
         chat_res = create_chat(assistant_id=assistant_res.assistant_id, name="test_chat")
         text = "hello, what is the weather like in HongKong?"
-        create_message_res = create_message(assistant_id=assistant_res.assistant_id,
-                                            chat_id=chat_res.chat_id, text=text)
-        generate_message_res = generate_message(assistant_id=assistant_res.assistant_id,
-                                                chat_id=chat_res.chat_id, system_prompt_variables={},
-                                                stream=True)
-        final_content = ''
+        create_message_res = create_message(
+            assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, text=text
+        )
+        generate_message_res = generate_message(
+            assistant_id=assistant_res.assistant_id, chat_id=chat_res.chat_id, system_prompt_variables={}, stream=True
+        )
+        final_content = ""
         for item in generate_message_res:
             if isinstance(item, MessageChunk):
                 logger.info(f"MessageChunk: {item.delta}")
@@ -619,30 +576,26 @@ def test_assistant_by_all_tool_and_stream(self, action_id):
         assert final_content is not None

     @pytest.mark.run(order=70)
-    def test_assistant_by_not_support_function_call_tool_and_stream(self, action_id):
-
+    def test_assistant_by_not_support_function_call_tool_and_stream(self):
         # Create an assistant.
         assistant_dict = {
             "model_id": Config.anthropic_chat_completion_model_id,
             "name": "test",
             "description": "test for assistant",
-            "memory": AssistantNaiveMemory(),
-            "system_prompt_template": ["You know the meaning of various numbers.",
-                                       "No matter what the user's language is, you will use the {{langugae}} to explain."],
+            "memory": AssistantMessageWindowMemory(max_tokens=2000),
+            "system_prompt_template": [
+                "You know the meaning of various numbers.",
+                "No matter what the user's language is, you will use the {{language}} to explain.",
+            ],
             "metadata": {"test": "test"},
             "tools": [
-                ToolRef(
-                    type=ToolType.ACTION,
-                    id=action_id,
-                ),
                 ToolRef(
                     type=ToolType.PLUGIN,
                     id="open_weather/get_hourly_forecast",
                 )
-            ]
+            ],
         }
-        with pytest.raises(Exception) as e:
-            assistant_res = create_assistant(**assistant_dict)
-        assert "not support function call to use the tools" in str(e.value)
+        assistant_res = create_assistant(**assistant_dict)
+        assistant_res_dict = vars(assistant_res)
diff --git a/test/testcase/test_sync/test_sync_retrieval.py b/test/testcase/test_sync/test_sync_retrieval.py
index 1937f21..9da3f2d 100644
--- a/test/testcase/test_sync/test_sync_retrieval.py
+++ b/test/testcase/test_sync/test_sync_retrieval.py
@@ -13,12 +13,7 @@
     get_record,
     update_record,
     delete_record,
-    query_chunks,
-    create_chunk,
-    update_chunk,
-    get_chunk,
-    delete_chunk,
-    list_chunks,
+    query_chunks
 )
 from taskingai.file import upload_file
 from test.config import Config
@@ -29,25 +24,42 @@
     assume_chunk_result,
     assume_query_chunk_result,
 )
+from test.testcase.test_sync import Base


 @pytest.mark.test_sync
-class TestCollection:
+class TestCollection(Base):
     @pytest.mark.run(order=21)
     def test_create_collection(self):
         # Create a collection.
-        create_dict = {
+        create_list = [
+            {
                 "capacity": 1000,
                 "embedding_model_id": Config.openai_text_embedding_model_id,
                 "name": "test",
                 "description": "description",
                 "metadata": {"key1": "value1", "key2": "value2"},
-        }
-        for x in range(2):
+            },
+            {
+                "capacity": 1000,
+                "embedding_model_id": Config.openai_text_embedding_model_id,
+                "type": "qa",
+                "name": "test",
+                "description": "description",
+                "metadata": {"key1": "value1", "key2": "value2"},
+            },
+        ]
+        for index, create_dict in enumerate(create_list):
             res = create_collection(**create_dict)
             res_dict = vars(res)
             logger.info(res_dict)
             assume_collection_result(create_dict, res_dict)
+            if index == 0:
+                Base.collection_id = res_dict["collection_id"]
+            else:
+                Base.qa_collection_id = res_dict["collection_id"]

     @pytest.mark.run(order=22)
     def test_list_collections(self):
@@ -70,20 +82,20 @@ def test_list_collections(self):
         pytest.assume(before_res[0] == twice_nums_list[0])

     @pytest.mark.run(order=23)
-    def test_get_collection(self, collection_id):
+    def test_get_collection(self):
         # Get a collection.
-        res = get_collection(collection_id=collection_id)
+        res = get_collection(collection_id=self.collection_id)
         res_dict = vars(res)
         pytest.assume(res_dict["status"] == "ready")
-        pytest.assume(res_dict["collection_id"] == collection_id)
+        pytest.assume(res_dict["collection_id"] == self.collection_id)

     @pytest.mark.run(order=24)
-    def test_update_collection(self, collection_id):
+    def test_update_collection(self):
         # Update a collection.
         update_collection_data = {
-            "collection_id": collection_id,
+            "collection_id": self.collection_id,
             "name": "test_update",
             "description": "description_update",
             "metadata": {"key1": "value1", "key2": "value2"},
@@ -115,7 +127,7 @@ def test_delete_collection(self):


 @pytest.mark.test_sync
-class TestRecord:
+class TestRecord(Base):
     text_splitter_list = [
         # {
         #     "type": "token",
@@ -127,7 +139,7 @@ class TestRecord:
         TextSplitter(type="separator", chunk_size=200, chunk_overlap=20, separators=[".", "!", "?"]),
     ]
     upload_file_data_list = []
-
+    upload_qa_file_data_list = []
     base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
     files = os.listdir(base_path + "/files")
     for file in files:
@@ -137,15 +149,23 @@ class TestRecord:
             upload_file_dict.update({"file": open(filepath, "rb")})
             upload_file_data_list.append(upload_file_dict)

+    qa_files = os.listdir(base_path + "/qa_files")
+    for file in qa_files:
+        filepath = os.path.join(base_path, "qa_files", file)
+        if os.path.isfile(filepath):
+            upload_qa_file_dict = {"purpose": "qa_record_file"}
+            upload_qa_file_dict.update({"file": open(filepath, "rb")})
+            upload_qa_file_data_list.append(upload_qa_file_dict)
+
     @pytest.mark.run(order=31)
     @pytest.mark.parametrize("text_splitter", text_splitter_list)
-    def test_create_record_by_text(self, collection_id, text_splitter):
+    def test_create_record_by_text(self, text_splitter):
         # Create a text record.
         text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
         create_record_data = {
             "type": "text",
             "title": "Machine learning",
-            "collection_id": collection_id,
+            "collection_id": self.collection_id,
             "content": text,
             "text_splitter": text_splitter,
             "metadata": {"key1": "value1", "key2": "value2"},
@@ -155,13 +175,13 @@ def test_create_record_by_text(self, collection_id, text_splitter):
         assume_record_result(create_record_data, res_dict)

     @pytest.mark.run(order=31)
-    def test_create_record_by_web(self, collection_id):
+    def test_create_record_by_web(self):
         # Create a web record.
         text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
         create_record_data = {
             "type": "web",
             "title": "TaskingAI",
-            "collection_id": collection_id,
+            "collection_id": self.collection_id,
             "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/",
             "text_splitter": text_splitter,
             "metadata": {"key1": "value1", "key2": "value2"},
@@ -173,7 +193,7 @@ def test_create_record_by_web(self, collection_id):

     @pytest.mark.run(order=32)
     @pytest.mark.parametrize("upload_file_data", upload_file_data_list[:2])
-    def test_create_record_by_file(self, collection_id, upload_file_data):
+    def test_create_record_by_file(self, upload_file_data):
         # upload file
         upload_file_res = upload_file(**upload_file_data)
         upload_file_dict = vars(upload_file_res)
@@ -184,7 +204,7 @@ def test_create_record_by_file(self, collection_id, upload_file_data):
         create_record_data = {
             "type": "file",
             "title": "TaskingAI",
-            "collection_id": collection_id,
+            "collection_id": self.collection_id,
             "file_id": file_id,
             "text_splitter": text_splitter,
             "metadata": {"key1": "value1", "key2": "value2"},
@@ -195,81 +215,146 @@ def test_create_record_by_file(self, collection_id, upload_file_data):
         assume_record_result(create_record_data, res_dict)

     @pytest.mark.run(order=32)
-    def test_list_records(self, collection_id):
+    @pytest.mark.parametrize("upload_qa_file_data", upload_qa_file_data_list)
+    def test_create_record_by_qa_file(self, upload_qa_file_data):
+        # upload file
+        upload_file_res = upload_file(**upload_qa_file_data)
+        upload_file_dict = vars(upload_file_res)
+        file_id = upload_file_dict["file_id"]
+        pytest.assume(file_id is not None)
+
+        text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
+        create_record_data = {
+            "type": "qa_sheet",
+            "collection_id": self.qa_collection_id,
+            "file_id": file_id,
+            "metadata": {"key1": "value1", "key2": "value2"},
+        }
+
+        res = create_record(**create_record_data)
+        res_dict = vars(res)
+        assume_record_result(create_record_data, res_dict)
+
+    @pytest.mark.run(order=32)
+    def test_list_records(self):
+        # List records.
+
+        nums_limit = 1
+        res = list_records(limit=nums_limit, collection_id=self.collection_id)
+        pytest.assume(len(res) == nums_limit)
+
+        after_id = res[-1].record_id
+        after_res = list_records(limit=nums_limit, after=after_id, collection_id=self.collection_id)
+        pytest.assume(len(after_res) == nums_limit)
+
+        twice_nums_list = list_records(limit=nums_limit * 2, collection_id=self.collection_id)
+        pytest.assume(len(twice_nums_list) == nums_limit * 2)
+        pytest.assume(after_res[-1] == twice_nums_list[-1])
+        pytest.assume(after_res[0] == twice_nums_list[nums_limit])
+
+        before_id = after_res[0].record_id
+        before_res = list_records(limit=nums_limit, before=before_id, collection_id=self.collection_id)
+        pytest.assume(len(before_res) == nums_limit)
+        pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
+        pytest.assume(before_res[0] == twice_nums_list[0])
+
+    @pytest.mark.run(order=32)
+    def test_list_qa_records(self):
         # List records.
         nums_limit = 1
-        res = list_records(limit=nums_limit, collection_id=collection_id)
+        res = list_records(limit=nums_limit, collection_id=self.qa_collection_id)
         pytest.assume(len(res) == nums_limit)

         after_id = res[-1].record_id
-        after_res = list_records(limit=nums_limit, after=after_id, collection_id=collection_id)
+        after_res = list_records(limit=nums_limit, after=after_id, collection_id=self.qa_collection_id)
         pytest.assume(len(after_res) == nums_limit)

-        twice_nums_list = list_records(limit=nums_limit * 2, collection_id=collection_id)
+        twice_nums_list = list_records(limit=nums_limit * 2, collection_id=self.qa_collection_id)
         pytest.assume(len(twice_nums_list) == nums_limit * 2)
         pytest.assume(after_res[-1] == twice_nums_list[-1])
         pytest.assume(after_res[0] == twice_nums_list[nums_limit])

         before_id = after_res[0].record_id
-        before_res = list_records(limit=nums_limit, before=before_id, collection_id=collection_id)
+        before_res = list_records(limit=nums_limit, before=before_id, collection_id=self.qa_collection_id)
         pytest.assume(len(before_res) == nums_limit)
         pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
         pytest.assume(before_res[0] == twice_nums_list[0])

     @pytest.mark.run(order=33)
-    def test_get_record(self, collection_id):
+    def test_get_record(self):
         # list records
-        records = list_records(collection_id=collection_id)
+        records = list_records(collection_id=self.collection_id)
         for record in records:
             record_id = record.record_id
-            res = get_record(collection_id=collection_id, record_id=record_id)
+            time.sleep(Config.sleep_time)
+            res = get_record(collection_id=self.collection_id, record_id=record_id)
             logger.info(f"get record response: {res}")
             res_dict = vars(res)
-            pytest.assume(res_dict["collection_id"] == collection_id)
+            pytest.assume(res_dict["collection_id"] == self.collection_id)
             pytest.assume(res_dict["record_id"] == record_id)
             pytest.assume(res_dict["status"] == "ready")
+            Base.record_id = record_id
+
+    @pytest.mark.run(order=33)
+    def test_get_qa_record(self):
+        # list records
+
+        records = list_records(collection_id=self.qa_collection_id)
+        for record in records:
+            record_id = record.record_id
+            time.sleep(Config.sleep_time)
+            res = get_record(collection_id=self.qa_collection_id, record_id=record_id)
+            logger.info(f"get record response: {res}")
+            res_dict = vars(res)
+            pytest.assume(res_dict["collection_id"] == self.qa_collection_id)
+            pytest.assume(res_dict["record_id"] == record_id)
+            pytest.assume(res_dict["status"] == "ready")
+            Base.qa_record_id = record_id

     @pytest.mark.run(order=34)
     @pytest.mark.parametrize("text_splitter", text_splitter_list)
-    def test_update_record_by_text(self, collection_id, record_id, text_splitter):
+    def test_update_record_by_text(self, text_splitter):
         # Update a record.
         update_record_data = {
             "type": "text",
             "title": "TaskingAI",
-            "collection_id": collection_id,
-            "record_id": record_id,
+            "collection_id": self.collection_id,
+            "record_id": self.record_id,
             "content": "TaskingAI is an AI-native application development platform that unifies modules like Model, Retrieval, Assistant, and Tool into one seamless ecosystem, streamlining the creation and deployment of applications for developers.",
             "text_splitter": text_splitter,
             "metadata": {"test": "test"},
         }
         res = update_record(**update_record_data)
+        time.sleep(Config.sleep_time)
         res_dict = vars(res)
         assume_record_result(update_record_data, res_dict)

     @pytest.mark.run(order=34)
     @pytest.mark.parametrize("text_splitter", text_splitter_list)
-    def test_update_record_by_web(self, collection_id, record_id, text_splitter):
+    def test_update_record_by_web(self, text_splitter):
         # Update a record.
         update_record_data = {
             "type": "web",
             "title": "TaskingAI",
-            "collection_id": collection_id,
-            "record_id": record_id,
+            "collection_id": self.collection_id,
+            "record_id": self.record_id,
             "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/",
             "text_splitter": text_splitter,
             "metadata": {"test": "test"},
         }
         res = update_record(**update_record_data)
+        time.sleep(Config.sleep_time)
         res_dict = vars(res)
         assume_record_result(update_record_data, res_dict)

     @pytest.mark.run(order=35)
     @pytest.mark.parametrize("upload_file_data", upload_file_data_list[2:3])
-    def test_update_record_by_file(self, collection_id, record_id, upload_file_data):
+    def test_update_record_by_file(self, upload_file_data):
         # upload file
         upload_file_res = upload_file(**upload_file_data)
         upload_file_dict = vars(upload_file_res)
@@ -282,8 +367,8 @@ def test_update_record_by_file(self, collection_id, record_id, upload_file_data)
         update_record_data = {
             "type": "file",
             "title": "TaskingAI",
-            "collection_id": collection_id,
-            "record_id": record_id,
+            "collection_id": self.collection_id,
+            "record_id": self.record_id,
             "file_id": file_id,
             "text_splitter": text_splitter,
             "metadata": {"test": "test"},
@@ -292,120 +377,84 @@ def test_update_record_by_file(self, collection_id, record_id, upload_file_data)
         res_dict = vars(res)
         assume_record_result(update_record_data, res_dict)

+    @pytest.mark.run(order=35)
+    @pytest.mark.parametrize("upload_qa_file_data", upload_qa_file_data_list)
+    def test_update_qa_record(self, upload_qa_file_data):
+        # upload file
+        upload_file_res = upload_file(**upload_qa_file_data)
+        upload_file_dict = vars(upload_file_res)
+        file_id = upload_file_dict["file_id"]
+        pytest.assume(file_id is not None)
+
+        # Update a record.
+        text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
+
+        update_record_data = {
+            "type": "qa_sheet",
+            "collection_id": self.qa_collection_id,
+            "record_id": self.qa_record_id,
+            "file_id": file_id,
+            "metadata": {"test": "test"},
+        }
+        res = update_record(**update_record_data)
+        res_dict = vars(res)
+        # assume_record_result(update_record_data, res_dict)
+
     @pytest.mark.run(order=79)
-    def test_delete_record(self, collection_id):
+    def test_delete_record(self):
         # List records.
         time.sleep(Config.sleep_time)
-        records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None)
+        records = list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, before=None)
         old_nums = len(records)
         for index, record in enumerate(records):
             record_id = record.record_id

             # Delete a record.
-            delete_record(collection_id=collection_id, record_id=record_id)
+            delete_record(collection_id=self.collection_id, record_id=record_id)

             # List records.
             if index == old_nums - 1:
-                new_records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None)
+                new_records = list_records(collection_id=self.collection_id, order="desc", limit=20, after=None, before=None)
+
+                new_nums = len(new_records)
+                pytest.assume(new_nums == 0)
+
+    @pytest.mark.run(order=79)
+    def test_delete_qa_record(self):
+        # List records.
+        time.sleep(Config.sleep_time)
+        records = list_records(collection_id=self.qa_collection_id, order="desc", limit=20, after=None, before=None)
+        old_nums = len(records)
+        for index, record in enumerate(records):
+            record_id = record.record_id
+
+            # Delete a record.
+
+            delete_record(collection_id=self.qa_collection_id, record_id=record_id)
+
+            # List records.
+            if index == old_nums - 1:
+                new_records = list_records(collection_id=self.qa_collection_id, order="desc", limit=20, after=None, before=None)

                 new_nums = len(new_records)
                 pytest.assume(new_nums == 0)


 @pytest.mark.test_sync
-class TestChunk:
+class TestChunk(Base):
     @pytest.mark.run(order=41)
-    def test_query_chunks(self, collection_id):
+    def test_query_chunks(self):
         # Query chunks.
         query_text = "Machine learning"
         top_k = 1
         res = query_chunks(
-            collection_id=collection_id, query_text=query_text, top_k=top_k, max_tokens=20000, score_threshold=0.04
+            collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000, score_threshold=0.04
         )
         pytest.assume(len(res) == top_k)
         for chunk in res:
             chunk_dict = vars(chunk)
             assume_query_chunk_result(query_text, chunk_dict)
             pytest.assume(chunk_dict["score"] >= 0.04)
-
-    @pytest.mark.run(order=42)
-    def test_create_chunk(self, collection_id):
-        # Create a chunk.
-        create_chunk_data = {
-            "collection_id": collection_id,
-            "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.",
-        }
-        res = create_chunk(**create_chunk_data)
-        res_dict = vars(res)
-        assume_chunk_result(create_chunk_data, res_dict)
-
-    @pytest.mark.run(order=43)
-    def test_list_chunks(self, collection_id):
-        # List chunks.
-
-        nums_limit = 1
-        res = list_chunks(limit=nums_limit, collection_id=collection_id)
-        pytest.assume(len(res) == nums_limit)
-
-        after_id = res[-1].chunk_id
-        after_res = list_chunks(limit=nums_limit, after=after_id, collection_id=collection_id)
-        pytest.assume(len(after_res) == nums_limit)
-
-        twice_nums_list = list_chunks(limit=nums_limit * 2, collection_id=collection_id)
-        pytest.assume(len(twice_nums_list) == nums_limit * 2)
-        pytest.assume(after_res[-1] == twice_nums_list[-1])
-        pytest.assume(after_res[0] == twice_nums_list[nums_limit])
-
-        before_id = after_res[0].chunk_id
-        before_res = list_chunks(limit=nums_limit, before=before_id, collection_id=collection_id)
-        pytest.assume(len(before_res) == nums_limit)
-        pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
-        pytest.assume(before_res[0] == twice_nums_list[0])
-
-    @pytest.mark.run(order=44)
-    def test_get_chunk(self, collection_id):
-        # list chunks
-
-        chunks = list_chunks(collection_id=collection_id)
-        for chunk in chunks:
-            chunk_id = chunk.chunk_id
-            res = get_chunk(collection_id=collection_id, chunk_id=chunk_id)
-            logger.info(f"get chunk response: {res}")
-            res_dict = vars(res)
-            pytest.assume(res_dict["collection_id"] == collection_id)
-            pytest.assume(res_dict["chunk_id"] == chunk_id)
-
-    @pytest.mark.run(order=45)
-    def test_update_chunk(self, collection_id, chunk_id):
-        # Update a chunk.
-
-        update_chunk_data = {
-            "collection_id": collection_id,
-            "chunk_id": chunk_id,
-            "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.",
-            "metadata": {"test": "test"},
-        }
-        res = update_chunk(**update_chunk_data)
-        res_dict = vars(res)
-        assume_chunk_result(update_chunk_data, res_dict)
-
-    @pytest.mark.run(order=46)
-    def test_delete_chunk(self, collection_id):
-        # List chunks.
-
-        chunks = list_chunks(collection_id=collection_id, limit=5)
-        for index, chunk in enumerate(chunks):
-            chunk_id = chunk.chunk_id
-
-            # Delete a chunk.
-
-            delete_chunk(collection_id=collection_id, chunk_id=chunk_id)
-
-            # List chunks.
-
-            new_chunks = list_chunks(collection_id=collection_id)
-            chunk_ids = [chunk.chunk_id for chunk in new_chunks]
-            pytest.assume(chunk_id not in chunk_ids)
diff --git a/test/testcase/test_sync/test_sync_tool.py b/test/testcase/test_sync/test_sync_tool.py
deleted file mode 100644
index 7c53cb6..0000000
--- a/test/testcase/test_sync/test_sync_tool.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import pytest
-from test.config import Config
-from taskingai.tool import bulk_create_actions, get_action, update_action, delete_action, run_action, list_actions, ActionAuthentication, ActionAuthenticationType
-from test.common.logger import logger
-
-
-@pytest.mark.test_sync
-class TestAction:
-
-    authentication_list = [
-        {
-            "type": "bearer",
-            "secret": "ASD213df"
-        },
-        ActionAuthentication(type=ActionAuthenticationType.BEARER, secret="ASD213df")
-    ]
-
-    @pytest.mark.run(order=11)
-    @pytest.mark.parametrize("authentication", authentication_list)
-    def test_bulk_create_actions(self, authentication):
-        schema = {
-            "openapi_schema": {
-                "openapi": "3.1.0",
-                "info": {
-                    "title": "Get weather data",
-                    "description": "Retrieves current weather data for a location.",
-                    "version": "v1.0.0"
-                },
-                "servers": [
-                    {
-                        "url": "https://weather.example.com"
-                    }
-                ],
-                "paths": {
-                    "/location": {
-                        "get": {
-                            "description": "Get temperature for a specific location 123",
-                            "operationId": "GetCurrentWeather123",
-                            "parameters": [
-                                {
-                                    "name": "location",
-                                    "in": "query",
-                                    "description": "The city and state to retrieve the weather for",
-                                    "required": True,
-                                    "schema": {
-                                        "type": "string"
-                                    }
-                                }
-                            ],
-                            "deprecated": False
-                        }
-                    }
-                }
-
-            }
-
-        }
-        schema.update({"authentication": authentication})
-
-        # Create an action.
-
-        res = bulk_create_actions(**schema)
-        for action in res:
-            action_dict = vars(action)
-            logger.info(action_dict)
-            for key in schema.keys():
-                if key != "authentication":
-                    for k, v in schema[key].items():
-                        pytest.assume(action_dict[key][k] == v)
-                else:
-                    if isinstance(schema[key], ActionAuthentication):
-                        schema[key] = vars(schema[key])
-                    for k, v in schema[key].items():
-                        if v is None:
-                            pytest.assume(vars(action_dict[key])[k] == v)
-                        elif k == "type":
-                            pytest.assume(vars(action_dict[key])[k] == v)
-                        else:
-                            pytest.assume("*" in vars(action_dict[key])[k])
-
-    @pytest.mark.run(order=12)
-    def test_list_actions(self):
-
-        # List actions.
-
-        nums_limit = 1
-        res = list_actions(limit=nums_limit)
-        logger.info(res)
-        pytest.assume(len(res) == nums_limit)
-
-        after_id = res[-1].action_id
-        after_res = list_actions(limit=nums_limit, after=after_id)
-        logger.info(after_res)
-        pytest.assume(len(after_res) == nums_limit)
-
-        twice_nums_list = list_actions(limit=nums_limit * 2)
-        logger.info(twice_nums_list)
-        pytest.assume(len(twice_nums_list) == nums_limit * 2)
-        pytest.assume(after_res[-1] == twice_nums_list[-1])
-        pytest.assume(after_res[0] == twice_nums_list[nums_limit])
-
-        before_id = after_res[0].action_id
-        before_res = list_actions(limit=nums_limit, before=before_id)
-        logger.info(before_res)
-        pytest.assume(len(before_res) == nums_limit)
-        pytest.assume(before_res[-1] == twice_nums_list[nums_limit - 1])
-        pytest.assume(before_res[0] == twice_nums_list[0])
-
-    @pytest.mark.run(order=13)
-    def test_get_action(self, action_id):
-
-        # Get an action.
-
-        res = get_action(action_id=action_id)
-        res_dict = vars(res)
-        logger.info(res_dict["openapi_schema"].keys())
-        pytest.assume(res_dict["action_id"] == action_id)
-
-    @pytest.mark.run(order=14)
-    @pytest.mark.parametrize("authentication", authentication_list)
-    def test_update_action(self, action_id, authentication):
-
-        # Update an action.
-
-        update_schema = {
-            "openapi_schema": {
-                "openapi": "3.0.0",
-                "info": {
-                    "title": "Numbers API",
-                    "version": "1.0.0",
-                    "description": "API for fetching interesting number facts"
-                },
-                "servers": [
-                    {
-                        "url": "http://numbersapi.com"
-                    }
-                ],
-                "paths": {
-                    "/{number}": {
-                        "get": {
-                            "description": "Get fact about a number",
-                            "operationId": "get_number_fact",
-                            "parameters": [
-                                {
-                                    "name": "number",
-                                    "in": "path",
-                                    "required": True,
-                                    "description": "The number to get the fact for",
-                                    "schema": {
-                                        "type": "integer"
-                                    }
-                                }
-                            ],
-                            "responses": {
-                                "200": {
-                                    "description": "A fact about the number",
-                                    "content": {
-                                        "text/plain": {
-                                            "schema": {
-                                                "type": "string"
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-        update_schema.update({"authentication": authentication})
-        res = update_action(action_id=action_id, **update_schema)
-        res_dict = vars(res)
-        logger.info(res_dict)
-        for key in update_schema.keys():
-            if key != "authentication":
-                for k, v in update_schema[key].items():
-                    pytest.assume(res_dict[key][k] == v)
-            else:
-                if isinstance(update_schema[key], ActionAuthentication):
-                    update_schema[key] = vars(update_schema[key])
-                for k, v in update_schema[key].items():
-                    if v is None:
-                        pytest.assume(vars(res_dict[key])[k] == v)
-                    elif k == "type":
-                        pytest.assume(vars(res_dict[key])[k] == v)
-                    else:
-                        pytest.assume("*" in vars(res_dict[key])[k])
-
-    @pytest.mark.run(order=15)
-    def test_run_action(self, action_id):
-
-        # Run an action.
-
-        parameters = {
-            "number": 42
-        }
-        res = run_action(action_id=action_id, parameters=parameters)
-        logger.info(f'async run action{res}')
-        pytest.assume(res['status'] == 200)
-        pytest.assume(res["data"])
-
-    @pytest.mark.run(order=80)
-    def test_delete_action(self):
-
-        # List actions.
-
-        actions = list_actions(limit=100)
-        old_nums = len(actions)
-
-        for index, action in enumerate(actions):
-            action_id = action.action_id
-
-            # Delete an action.
-
-            delete_action(action_id=action_id)
-
-            if index == old_nums-1:
-                new_actions = list_actions(limit=100)
-                new_nums = len(new_actions)
-                pytest.assume(new_nums == 0)
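
Note on the test migration above: the sync tests no longer receive `collection_id`, `record_id`, or `action_id` through pytest fixtures; they share state via a `Base` class imported from `test.testcase.test_sync`, which this diff does not show. A minimal, hypothetical sketch of what that module presumably contains, with attribute names taken from the assignments in the tests:

```python
# Hypothetical reconstruction of test/testcase/test_sync/__init__.py.
# The diff only shows `from test.testcase.test_sync import Base`; the
# attribute names below mirror the assignments made in the tests above.


class Base:
    """Class-level holders that let ordered tests hand IDs to later tests."""

    collection_id = None     # set in TestCollection.test_create_collection
    qa_collection_id = None  # set for the collection created with type="qa"
    record_id = None         # set in TestRecord.test_get_record
    qa_record_id = None      # set in TestRecord.test_get_qa_record
```

Because pytest instantiates each test class independently, class attributes on a shared base are a simple way to pass IDs between ordered tests without fixture plumbing.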
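For orientation, the new QA retrieval path these tests exercise reduces to three calls: create a collection with `type="qa"`, upload a sheet with `purpose="qa_record_file"`, and register it as a `qa_sheet` record. A condensed sketch, assuming the same client functions the tests import; the embedding model ID and file path are placeholders:

```python
from taskingai.file import upload_file
from taskingai.retrieval import create_collection, create_record

# Create a QA-type collection; regular collections simply omit "type".
qa_coll = create_collection(
    embedding_model_id="YOUR_EMBEDDING_MODEL_ID",  # placeholder
    capacity=1000,
    type="qa",
)

# Upload a QA sheet (placeholder path) and attach it as a qa_sheet record.
with open("qa_files/example_qa_sheet.xlsx", "rb") as f:
    qa_file = upload_file(file=f, purpose="qa_record_file")

record = create_record(
    collection_id=qa_coll.collection_id,
    type="qa_sheet",
    file_id=qa_file.file_id,
)
```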