Skip to content

Commit

Permalink
test: add tests for assistant, record, chunk
Browse files Browse the repository at this point in the history
  • Loading branch information
taskingaijc authored and Dttbd committed May 7, 2024
1 parent 5f858c9 commit 66e228e
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 118 deletions.
2 changes: 1 addition & 1 deletion taskingai/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__title__ = "taskingai"
__version__ = "0.2.3"
__version__ = "0.2.4"
9 changes: 5 additions & 4 deletions test/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,11 @@ def assume_assistant_result(assistant_dict: dict, res: dict):
if key == 'system_prompt_template' and isinstance(value, str):
pytest.assume(res[key] == [assistant_dict[key]])
elif key in ['retrieval_configs']:
if isinstance(value, dict):
pytest.assume(vars(res[key]) == assistant_dict[key])
else:
pytest.assume(res[key] == assistant_dict[key])
continue
# if isinstance(value, dict):
# pytest.assume(vars(res[key]) == assistant_dict[key])
# else:
# pytest.assume(res[key] == assistant_dict[key])
elif key in ["memory", "tools", "retrievals"]:
continue
else:
Expand Down
16 changes: 11 additions & 5 deletions test/testcase/test_async/test_async_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def test_a_create_assistant(self):
method="memory",
top_k=1,
max_tokens=5000,
score_threshold=0.5

),
"tools": [
Expand All @@ -54,7 +55,7 @@ async def test_a_create_assistant(self):
if i == 0:
assistant_dict.update({"memory": {"type": "naive"}})
assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]})
assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}})
assistant_dict.update({"tools": [{"type": "action", "id": self.action_id},
{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]})
res = await a_create_assistant(**assistant_dict)
Expand Down Expand Up @@ -119,6 +120,7 @@ async def test_a_update_assistant(self):
method="memory",
top_k=2,
max_tokens=4000,
score_threshold=0.5

),
"tools": [
Expand All @@ -137,7 +139,7 @@ async def test_a_update_assistant(self):
"description": "test for openai",
"memory": {"type": "naive"},
"retrievals": [{"type": "collection", "id": self.collection_id}],
"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000},
"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5},
"tools": [{"type": "action", "id": self.action_id},
{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]

Expand Down Expand Up @@ -365,6 +367,7 @@ async def test_a_generate_message_by_stream(self):
method="memory",
top_k=1,
max_tokens=5000,
score_threshold=0.04

),
"tools": [
Expand Down Expand Up @@ -435,7 +438,8 @@ async def test_a_assistant_by_user_message_retrieval_and_stream(self):
"retrieval_configs": {
"method": "user_message",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5
}
}

Expand Down Expand Up @@ -482,7 +486,8 @@ async def test_a_assistant_by_memory_retrieval_and_stream(self):
"retrieval_configs": {
"method": "memory",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5

}
}
Expand Down Expand Up @@ -534,7 +539,8 @@ async def test_a_assistant_by_function_call_retrieval_and_stream(self):
{
"method": "function_call",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5
}
}

Expand Down
29 changes: 12 additions & 17 deletions test/testcase/test_async/test_async_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

@pytest.mark.test_async
class TestCollection(Base):

@pytest.mark.run(order=21)
@pytest.mark.asyncio
async def test_a_create_collection(self):
Expand Down Expand Up @@ -101,10 +100,11 @@ async def test_a_delete_collection(self):

@pytest.mark.test_async
class TestRecord(Base):

text_splitter_list = [
{"type": "token", "chunk_size": 100, "chunk_overlap": 10},
TokenTextSplitter(chunk_size=200, chunk_overlap=20),
# {"type": "token", "chunk_size": 100, "chunk_overlap": 10},
# TokenTextSplitter(chunk_size=200, chunk_overlap=20),
{"type": "separator", "chunk_size": 100, "chunk_overlap": 10, "separators": [".", "!", "?"]},
TextSplitter(type="separator", chunk_size=200, chunk_overlap=20, separators=[".", "!", "?"]),
]

upload_file_data_list = []
Expand All @@ -120,8 +120,8 @@ class TestRecord(Base):

@pytest.mark.run(order=31)
@pytest.mark.asyncio
async def test_a_create_record_by_text(self):
text_splitter = TokenTextSplitter(chunk_size=200, chunk_overlap=100)
@pytest.mark.parametrize("text_splitter", text_splitter_list)
async def test_a_create_record_by_text(self, text_splitter):
text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
create_record_data = {
"type": "text",
Expand All @@ -131,16 +131,10 @@ async def test_a_create_record_by_text(self):
"text_splitter": text_splitter,
"metadata": {"key1": "value1", "key2": "value2"},
}

for x in range(2):
# Create a record.
if x == 0:
create_record_data.update({"text_splitter": {"type": "token", "chunk_size": 100, "chunk_overlap": 10}})

res = await a_create_record(**create_record_data)
res_dict = vars(res)
assume_record_result(create_record_data, res_dict)
Base.record_id = res_dict["record_id"]
res = await a_create_record(**create_record_data)
res_dict = vars(res)
assume_record_result(create_record_data, res_dict)
Base.record_id = res_dict["record_id"]

@pytest.mark.run(order=31)
@pytest.mark.asyncio
Expand Down Expand Up @@ -332,13 +326,14 @@ async def test_a_query_chunks(self):
query_text = "Machine learning"
top_k = 1
res = await a_query_chunks(
collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000
collection_id=self.collection_id, query_text=query_text, top_k=top_k, max_tokens=20000, score_threshold=0.04
)
pytest.assume(len(res) == top_k)
for chunk in res:
chunk_dict = vars(chunk)
assume_query_chunk_result(query_text, chunk_dict)
pytest.assume(chunk_dict.keys() == self.chunk_keys)
pytest.assume(chunk_dict["score"] >= 0.04)

@pytest.mark.run(order=42)
@pytest.mark.asyncio
Expand Down
15 changes: 10 additions & 5 deletions test/testcase/test_sync/test_sync_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_create_assistant(self, collection_id, action_id):
method="memory",
top_k=1,
max_tokens=5000,
score_threshold=0.5

),
"tools": [
Expand All @@ -50,7 +51,7 @@ def test_create_assistant(self, collection_id, action_id):
if i == 0:
assistant_dict.update({"memory": {"type": "naive"}})
assistant_dict.update({"retrievals": [{"type": "collection", "id": collection_id}]})
assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5}})
assistant_dict.update({"tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]})

res = create_assistant(**assistant_dict)
Expand Down Expand Up @@ -111,6 +112,7 @@ def test_update_assistant(self, collection_id, action_id, assistant_id):
method="memory",
top_k=2,
max_tokens=4000,
score_threshold=0.5

),
"tools": [
Expand All @@ -129,7 +131,7 @@ def test_update_assistant(self, collection_id, action_id, assistant_id):
"description": "test for openai",
"memory": {"type": "naive"},
"retrievals": [{"type": "collection", "id": collection_id}],
"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000},
"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000, "score_threshold": 0.5},
"tools": [{"type": "action", "id": action_id}, {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]

}
Expand Down Expand Up @@ -408,7 +410,8 @@ def test_assistant_by_user_message_retrieval_and_stream(self, collection_id):
"retrieval_configs": {
"method": "user_message",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5
}
}

Expand Down Expand Up @@ -457,7 +460,8 @@ def test_assistant_by_memory_retrieval_and_stream(self, collection_id):
"retrieval_configs": {
"method": "memory",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5

}
}
Expand Down Expand Up @@ -508,7 +512,8 @@ def test_assistant_by_function_call_retrieval_and_stream(self, collection_id):
{
"method": "function_call",
"top_k": 1,
"max_tokens": 5000
"max_tokens": 5000,
"score_threshold": 0.5
}
}

Expand Down
Loading

0 comments on commit 66e228e

Please sign in to comment.