diff --git a/taskingai/_version.py b/taskingai/_version.py
index b73ba3b..d7951dc 100644
--- a/taskingai/_version.py
+++ b/taskingai/_version.py
@@ -1,2 +1,2 @@
 __title__ = "taskingai"
-__version__ = "0.2.3"
+__version__ = "0.2.4"
diff --git a/test/common/utils.py b/test/common/utils.py
index 7b3426f..dacb9bf 100644
--- a/test/common/utils.py
+++ b/test/common/utils.py
@@ -136,10 +136,11 @@ def assume_assistant_result(assistant_dict: dict, res: dict):
         if key == 'system_prompt_template' and isinstance(value, str):
             pytest.assume(res[key] == [assistant_dict[key]])
         elif key in ['retrieval_configs']:
-            if isinstance(value, dict):
-                pytest.assume(vars(res[key]) == assistant_dict[key])
-            else:
-                pytest.assume(res[key] == assistant_dict[key])
+            continue
+            # if isinstance(value, dict):
+            #     pytest.assume(vars(res[key]) == assistant_dict[key])
+            # else:
+            #     pytest.assume(res[key] == assistant_dict[key])
         elif key in ["memory", "tools", "retrievals"]:
             continue
         else:
diff --git a/test/testcase/test_async/test_async_assistant.py b/test/testcase/test_async/test_async_assistant.py
index 5ab22a4..7237b1b 100644
--- a/test/testcase/test_async/test_async_assistant.py
+++ b/test/testcase/test_async/test_async_assistant.py
@@ -37,6 +37,7 @@ async def test_a_create_assistant(self):
                 method="memory",
                 top_k=1,
                 max_tokens=5000,
+                score_threshold=0.5
             ),
             "tools": [
@@ -54,7 +55,7 @@ async def test_a_create_assistant(self):
         if i == 0:
             assistant_dict.update({"memory": {"type": "naive"}})
             assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]})
-            assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
+            assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}})
             assistant_dict.update({"tools": [{"type": "action", "id": self.action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]})

         res = await a_create_assistant(**assistant_dict)
@@ -119,6 +120,7 @@ async def test_a_update_assistant(self):
                 method="memory",
                 top_k=2,
                 max_tokens=4000,
+                score_threshold=0.5
             ),
             "tools": [
@@ -137,7 +139,7 @@ async def test_a_update_assistant(self):
             "description": "test for openai",
             "memory": {"type": "naive"},
             "retrievals": [{"type": "collection", "id": self.collection_id}],
-            "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000},
+            "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5},
             "tools": [{"type": "action", "id": self.action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]

@@ -365,6 +367,7 @@ async def test_a_generate_message_by_stream(self):
                 method="memory",
                 top_k=1,
                 max_tokens=5000,
+                score_threshold=0.04
             ),
             "tools": [
@@ -435,7 +438,8 @@ async def test_a_assistant_by_user_message_retrieval_and_stream(self):
            "retrieval_configs": {
                "method": "user_message",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
@@ -482,7 +486,8 @@ async def test_a_assistant_by_memory_retrieval_and_stream(self):
            "retrieval_configs": {
                "method": "memory",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
@@ -534,7 +539,8 @@ async def test_a_assistant_by_function_call_retrieval_and_stream(self):
            {
                "method": "function_call",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
diff --git a/test/testcase/test_async/test_async_retrieval.py b/test/testcase/test_async/test_async_retrieval.py
index 96f037d..090558d 100644
--- a/test/testcase/test_async/test_async_retrieval.py
+++ b/test/testcase/test_async/test_async_retrieval.py
@@ -17,7 +17,6 @@

 @pytest.mark.test_async
 class TestCollection(Base):
-
     @pytest.mark.run(order=21)
     @pytest.mark.asyncio
     async def test_a_create_collection(self):
@@ -101,10 +100,11 @@ async def test_a_delete_collection(self):

 @pytest.mark.test_async
 class TestRecord(Base):
-
     text_splitter_list = [
-        {"type": "token", "chunk_size": 100, "chunk_overlap": 10},
-        TokenTextSplitter(chunk_size=200, chunk_overlap=20),
+        # {"type": "token", "chunk_size": 100, "chunk_overlap": 10},
+        # TokenTextSplitter(chunk_size=200, chunk_overlap=20),
+        {"type": "separator", "chunk_size": 100, "chunk_overlap": 10, "separators": [".", "!", "?"]},
+        TextSplitter(type="separator", chunk_size=200, chunk_overlap=20, separators=[".", "!", "?"]),
     ]
     upload_file_data_list = []
@@ -120,8 +120,8 @@ class TestRecord(Base):

     @pytest.mark.run(order=31)
     @pytest.mark.asyncio
-    async def test_a_create_record_by_text(self):
-        text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=100)
+    @pytest.mark.parametrize("text_splitter", text_splitter_list)
+    async def test_a_create_record_by_text(self, text_splitter):
         text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
         create_record_data = {
             "type": "text",
@@ -131,16 +131,10 @@ async def test_a_create_record_by_text(self, text_splitter):
             "text_splitter": text_splitter,
             "metadata": {"key1": "value1", "key2": "value2"},
         }
-
-        for x in range(2):
-            # Create a record.
-            if x == 0:
-                create_record_data.update({"text_splitter": {"type": "token", "chunk_size": 100, "chunk_overlap": 10}})
-
-            res = await a_create_record(**create_record_data)
-            res_dict = vars(res)
-            assume_record_result(create_record_data, res_dict)
-            Base.record_id = res_dict["record_id"]
+        res = await a_create_record(**create_record_data)
+        res_dict = vars(res)
+        assume_record_result(create_record_data, res_dict)
+        Base.record_id = res_dict["record_id"]

     @pytest.mark.run(order=31)
     @pytest.mark.asyncio
@@ -332,13 +326,14 @@ async def test_a_query_chunks(self):
         query_text = "Machine learning"
         top_k = 1
         res = await a_query_chunks(
-            collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000
+            collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000, score_threshold=0.04
         )
         pytest.assume(len(res) == top_k)
         for chunk in res:
             chunk_dict = vars(chunk)
             assume_query_chunk_result(query_text, chunk_dict)
             pytest.assume(chunk_dict.keys() == self.chunk_keys)
+            pytest.assume(chunk_dict["score"] >= 0.04)

     @pytest.mark.run(order=42)
     @pytest.mark.asyncio
diff --git a/test/testcase/test_sync/test_sync_assistant.py b/test/testcase/test_sync/test_sync_assistant.py
index f9dfde0..1dcc275 100644
--- a/test/testcase/test_sync/test_sync_assistant.py
+++ b/test/testcase/test_sync/test_sync_assistant.py
@@ -33,6 +33,7 @@ def test_create_assistant(self, collection_id, action_id):
                 method="memory",
                 top_k=1,
                 max_tokens=5000,
+                score_threshold=0.5
             ),
             "tools": [
@@ -50,7 +51,7 @@ def test_create_assistant(self, collection_id, action_id):
         if i == 0:
             assistant_dict.update({"memory": {"type": "naive"}})
             assistant_dict.update({"retrievals": [{"type": "collection", "id": collection_id}]})
-            assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
+            assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}})
             assistant_dict.update({"tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]})

         res = create_assistant(**assistant_dict)
@@ -111,6 +112,7 @@ def test_update_assistant(self, collection_id, action_id, assistant_id):
                 method="memory",
                 top_k=2,
                 max_tokens=4000,
+                score_threshold=0.5
             ),
             "tools": [
@@ -129,7 +131,7 @@ def test_update_assistant(self, collection_id, action_id, assistant_id):
             "description": "test for openai",
             "memory": {"type": "naive"},
             "retrievals": [{"type": "collection", "id": collection_id}],
-            "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000},
+            "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5},
             "tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]
         }
@@ -408,7 +410,8 @@ def test_assistant_by_user_message_retrieval_and_stream(self, collection_id):
            "retrieval_configs": {
                "method": "user_message",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
@@ -457,7 +460,8 @@ def test_assistant_by_memory_retrieval_and_stream(self, collection_id):
            "retrieval_configs": {
                "method": "memory",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
@@ -508,7 +512,8 @@ def test_assistant_by_function_call_retrieval_and_stream(self, collection_id):
            {
                "method": "function_call",
                "top_k": 1,
-               "max_tokens": 5000
+               "max_tokens": 5000,
+               "score_threshold": 0.5
            }
        }
diff --git a/test/testcase/test_sync/test_sync_retrieval.py b/test/testcase/test_sync/test_sync_retrieval.py
index 834e9dd..8ebe111 100644
--- a/test/testcase/test_sync/test_sync_retrieval.py
+++ b/test/testcase/test_sync/test_sync_retrieval.py
@@ -1,30 +1,47 @@
 import pytest
 import os

-from taskingai.retrieval import Record, TokenTextSplitter
-from taskingai.retrieval import list_collections, create_collection, get_collection, update_collection, delete_collection, list_records, create_record, get_record, update_record, delete_record, query_chunks, create_chunk, update_chunk, get_chunk, delete_chunk, list_chunks
+from taskingai.retrieval import TokenTextSplitter, TextSplitter
+from taskingai.retrieval import (
+    list_collections,
+    create_collection,
+    get_collection,
+    update_collection,
+    delete_collection,
+    list_records,
+    create_record,
+    get_record,
+    update_record,
+    delete_record,
+    query_chunks,
+    create_chunk,
+    update_chunk,
+    get_chunk,
+    delete_chunk,
+    list_chunks,
+)
 from taskingai.file import upload_file

 from test.config import Config
 from test.common.logger import logger
-from test.common.utils import assume_collection_result, assume_record_result, assume_chunk_result, assume_query_chunk_result
+from test.common.utils import (
+    assume_collection_result,
+    assume_record_result,
+    assume_chunk_result,
+    assume_query_chunk_result,
+)


 @pytest.mark.test_sync
 class TestCollection:
-
     @pytest.mark.run(order=21)
     def test_create_collection(self):
-        # Create a collection.
         create_dict = {
             "capacity": 1000,
             "embedding_model_id": Config.openai_text_embedding_model_id,
             "name": "test",
             "description": "description",
-            "metadata": {
-                "key1": "value1",
-                "key2": "value2"
-            }
+            "metadata": {"key1": "value1", "key2": "value2"},
         }
         for x in range(2):
             res = create_collection(**create_dict)
@@ -34,7 +51,6 @@ def test_create_collection(self):

     @pytest.mark.run(order=22)
     def test_list_collections(self):
-        # List collections.

         nums_limit = 1
@@ -55,7 +71,6 @@ def test_list_collections(self):

     @pytest.mark.run(order=23)
     def test_get_collection(self, collection_id):
-        # Get a collection.

         res = get_collection(collection_id=collection_id)
@@ -65,17 +80,13 @@ def test_get_collection(self, collection_id):

     @pytest.mark.run(order=24)
     def test_update_collection(self, collection_id):
-        # Update a collection.

         update_collection_data = {
             "collection_id": collection_id,
             "name": "test_update",
             "description": "description_update",
-            "metadata": {
-                "key1": "value1",
-                "key2": "value2"
-            }
+            "metadata": {"key1": "value1", "key2": "value2"},
         }
         res = update_collection(**update_collection_data)
         res_dict = vars(res)
@@ -83,10 +94,9 @@ def test_update_collection(self, collection_id):

     @pytest.mark.run(order=80)
     def test_delete_collection(self):
-
         # List collections.

-        old_res = list_collections(order="desc", limit=100, after=None, before=None)
+        old_res = list_collections(order="desc", limit=100, after=None, before=None)

         old_nums = len(old_res)
         for index, collection in enumerate(old_res):
@@ -95,8 +105,8 @@ def test_delete_collection(self):

             # Delete a collection.
             delete_collection(collection_id=collection_id)
-            if index == old_nums-1:
-                new_collections = list_collections(order="desc", limit=100, after=None, before=None)
+            if index == old_nums - 1:
+                new_collections = list_collections(order="desc", limit=100, after=None, before=None)

                 # List collections.
@@ -106,14 +116,15 @@

 @pytest.mark.test_sync
 class TestRecord:
-
     text_splitter_list = [
-        {
-            "type": "token",  # "type": "token
-            "chunk_size": 100,
-            "chunk_overlap": 10
-        },
-        TokenTextSplitter(chunk_size=200, chunk_overlap=20)
+        # {
+        #     "type": "token",
+        #     "chunk_size": 100,
+        #     "chunk_overlap": 10
+        # },
+        # TokenTextSplitter(chunk_size=200, chunk_overlap=20),
+        {"type": "separator", "chunk_size": 100, "chunk_overlap": 10, "separators": [".", "!", "?"]},
+        TextSplitter(type="separator", chunk_size=200, chunk_overlap=20, separators=[".", "!", "?"]),
     ]
     upload_file_data_list = []
@@ -122,17 +133,14 @@ class TestRecord:
     for file in files:
         filepath = os.path.join(base_path, "files", file)
         if os.path.isfile(filepath):
-            upload_file_dict = {
-                "purpose": "record_file"
-            }
+            upload_file_dict = {"purpose": "record_file"}
             upload_file_dict.update({"file": open(filepath, "rb")})
             upload_file_data_list.append(upload_file_dict)

     @pytest.mark.run(order=31)
-    def test_create_record_by_text(self, collection_id):
-
+    @pytest.mark.parametrize("text_splitter", text_splitter_list)
+    def test_create_record_by_text(self, collection_id, text_splitter):
         # Create a text record.
-        text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
         text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
         create_record_data = {
             "type": "text",
@@ -140,26 +148,14 @@ def test_create_record_by_text(self, collection_id, text_splitter):
             "collection_id": collection_id,
             "content": text,
             "text_splitter": text_splitter,
-            "metadata": {
-                "key1": "value1",
-                "key2": "value2"
-            }
+            "metadata": {"key1": "value1", "key2": "value2"},
         }
-        for x in range(2):
-            if x == 0:
-                create_record_data.update(
-                    {"text_splitter": {
-                        "type": "token",
-                        "chunk_size": 100,
-                        "chunk_overlap": 10
-                    }})
-            res = create_record(**create_record_data)
-            res_dict = vars(res)
-            assume_record_result(create_record_data, res_dict)
+        res = create_record(**create_record_data)
+        res_dict = vars(res)
+        assume_record_result(create_record_data, res_dict)

     @pytest.mark.run(order=31)
     def test_create_record_by_web(self, collection_id):
-
         # Create a web record.
         text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=20)
         create_record_data = {
@@ -168,10 +164,7 @@ def test_create_record_by_web(self, collection_id):
             "collection_id": collection_id,
             "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/",
             "text_splitter": text_splitter,
-            "metadata": {
-                "key1": "value1",
-                "key2": "value2"
-            }
+            "metadata": {"key1": "value1", "key2": "value2"},
         }

         res = create_record(**create_record_data)
@@ -181,7 +174,6 @@ def test_create_record_by_web(self, collection_id):
     @pytest.mark.run(order=31)
     @pytest.mark.parametrize("upload_file_data", upload_file_data_list[:2])
     def test_create_record_by_file(self, collection_id, upload_file_data):
-
         # upload file
         upload_file_res = upload_file(**upload_file_data)
         upload_file_dict = vars(upload_file_res)
@@ -195,10 +187,7 @@ def test_create_record_by_file(self, collection_id, upload_file_data):
             "collection_id": collection_id,
             "file_id": file_id,
             "text_splitter": text_splitter,
-            "metadata": {
-                "key1": "value1",
-                "key2": "value2"
-            }
+            "metadata": {"key1": "value1", "key2": "value2"},
         }

         res = create_record(**create_record_data)
@@ -207,7 +196,6 @@ def test_create_record_by_file(self, collection_id, upload_file_data):

     @pytest.mark.run(order=32)
     def test_list_records(self, collection_id):
-
         # List records.

         nums_limit = 1
@@ -231,14 +219,13 @@ def test_list_records(self, collection_id):

     @pytest.mark.run(order=33)
     def test_get_record(self, collection_id):
-
         # list records
         records = list_records(collection_id=collection_id)

         for record in records:
             record_id = record.record_id
             res = get_record(collection_id=collection_id, record_id=record_id)
-            logger.info(f'get record response: {res}')
+            logger.info(f"get record response: {res}")
             res_dict = vars(res)
             pytest.assume(res_dict["collection_id"] == collection_id)
             pytest.assume(res_dict["record_id"] == record_id)
@@ -247,7 +234,6 @@ def test_get_record(self, collection_id):
     @pytest.mark.run(order=34)
     @pytest.mark.parametrize("text_splitter", text_splitter_list)
     def test_update_record_by_text(self, collection_id, record_id, text_splitter):
-
         # Update a record.

         update_record_data = {
@@ -257,7 +243,7 @@ def test_update_record_by_text(self, collection_id, record_id, text_splitter):
             "record_id": record_id,
             "content": "TaskingAI is an AI-native application development platform that unifies modules like Model, Retrieval, Assistant, and Tool into one seamless ecosystem, streamlining the creation and deployment of applications for developers.",
             "text_splitter": text_splitter,
-            "metadata": {"test": "test"}
+            "metadata": {"test": "test"},
         }
         res = update_record(**update_record_data)
         res_dict = vars(res)
@@ -266,7 +252,6 @@ def test_update_record_by_text(self, collection_id, record_id, text_splitter):
     @pytest.mark.run(order=34)
     @pytest.mark.parametrize("text_splitter", text_splitter_list)
     def test_update_record_by_web(self, collection_id, record_id, text_splitter):
-
         # Update a record.

         update_record_data = {
@@ -276,7 +261,7 @@ def test_update_record_by_web(self, collection_id, record_id, text_splitter):
             "record_id": record_id,
             "url": "https://docs.tasking.ai/docs/guide/getting_started/overview/",
             "text_splitter": text_splitter,
-            "metadata": {"test": "test"}
+            "metadata": {"test": "test"},
         }
         res = update_record(**update_record_data)
         res_dict = vars(res)
@@ -285,7 +270,6 @@ def test_update_record_by_web(self, collection_id, record_id, text_splitter):
     @pytest.mark.run(order=34)
     @pytest.mark.parametrize("upload_file_data", upload_file_data_list[2:3])
     def test_update_record_by_file(self, collection_id, record_id, upload_file_data):
-
         # upload file
         upload_file_res = upload_file(**upload_file_data)
         upload_file_dict = vars(upload_file_res)
@@ -302,7 +286,7 @@ def test_update_record_by_file(self, collection_id, record_id, upload_file_data):
             "record_id": record_id,
             "file_id": file_id,
             "text_splitter": text_splitter,
-            "metadata": {"test": "test"}
+            "metadata": {"test": "test"},
         }
         res = update_record(**update_record_data)
         res_dict = vars(res)
@@ -310,11 +294,9 @@ def test_update_record_by_file(self, collection_id, record_id, upload_file_data):

     @pytest.mark.run(order=79)
     def test_delete_record(self, collection_id):
-
         # List records.

-        records = list_records(collection_id=collection_id, order="desc", limit=20, after=None,
-                               before=None)
+        records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None)
         old_nums = len(records)
         for index, record in enumerate(records):
             record_id = record.record_id
@@ -324,9 +306,8 @@ def test_delete_record(self, collection_id):
             delete_record(collection_id=collection_id, record_id=record_id)

             # List records.
-            if index == old_nums-1:
-                new_records = list_records(collection_id=collection_id, order="desc", limit=20, after=None,
-                                           before=None)
+            if index == old_nums - 1:
+                new_records = list_records(collection_id=collection_id, order="desc", limit=20, after=None, before=None)
                 new_nums = len(new_records)
                 pytest.assume(new_nums == 0)

@@ -334,31 +315,42 @@

 @pytest.mark.test_sync
 class TestChunk:
-
-    chunk_list = ["chunk_id", "record_id", "collection_id", "content", "metadata", "num_tokens", "score", "updated_timestamp","created_timestamp"]
+    chunk_list = [
+        "chunk_id",
+        "record_id",
+        "collection_id",
+        "content",
+        "metadata",
+        "num_tokens",
+        "score",
+        "updated_timestamp",
+        "created_timestamp",
+    ]
     chunk_keys = set(chunk_list)

     @pytest.mark.run(order=41)
     def test_query_chunks(self, collection_id):
-
         # Query chunks.
query_text = "Machine learning" top_k = 1 - res = query_chunks(collection_id=collection_id, query_text=query_text, top_k=top_k, max_tokens=20000) + res = query_chunks( + collection_id=collection_id, query_text=query_text, top_k=top_k, max_tokens=20000, score_threshold=0.04 + ) pytest.assume(len(res) == top_k) for chunk in res: chunk_dict = vars(chunk) assume_query_chunk_result(query_text, chunk_dict) pytest.assume(chunk_dict.keys() == self.chunk_keys) + pytest.assume(chunk_dict["score"] >= 0.04) @pytest.mark.run(order=42) def test_create_chunk(self, collection_id): - # Create a chunk. create_chunk_data = { "collection_id": collection_id, - "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."} + "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", + } res = create_chunk(**create_chunk_data) res_dict = vars(res) pytest.assume(res_dict.keys() == self.chunk_keys) @@ -366,7 +358,6 @@ def test_create_chunk(self, collection_id): @pytest.mark.run(order=43) def test_list_chunks(self, collection_id): - # List chunks. nums_limit = 1 @@ -390,14 +381,13 @@ def test_list_chunks(self, collection_id): @pytest.mark.run(order=44) def test_get_chunk(self, collection_id): - # list chunks chunks = list_chunks(collection_id=collection_id) for chunk in chunks: chunk_id = chunk.chunk_id res = get_chunk(collection_id=collection_id, chunk_id=chunk_id) - logger.info(f'get chunk response: {res}') + logger.info(f"get chunk response: {res}") res_dict = vars(res) pytest.assume(res_dict["collection_id"] == collection_id) pytest.assume(res_dict["chunk_id"] == chunk_id) @@ -405,14 +395,13 @@ def test_get_chunk(self, collection_id): @pytest.mark.run(order=45) def test_update_chunk(self, collection_id, chunk_id): - # Update a chunk. update_chunk_data = { "collection_id": collection_id, "chunk_id": chunk_id, "content": "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data.", - "metadata": {"test": "test"} + "metadata": {"test": "test"}, } res = update_chunk(**update_chunk_data) res_dict = vars(res) @@ -421,7 +410,6 @@ def test_update_chunk(self, collection_id, chunk_id): @pytest.mark.run(order=46) def test_delete_chunk(self, collection_id): - # List chunks. chunks = list_chunks(collection_id=collection_id, limit=5) diff --git a/test_requirements.txt b/test_requirements.txt index ac239c9..75c5f70 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -6,7 +6,7 @@ randomize>=0.13 pytest==7.4.4 allure-pytest==2.13.5 pytest-ordering==0.6 -pytest-xdist==3.5.0 +pytest-xdist==3.6.1 PyYAML==6.0.1 pytest-assume==2.4.3 pytest-asyncio==0.23.6