Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade milvus parallel embed usage #408

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/en/Best Practice/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ docs = Document()

# (1)
docs.create_node_group(name='block',
transform=lambda d: '\n'.split(d))
transform=lambda d: d.split('\n'))

# (2)
docs.create_node_group(name='doc-summary',
transform=lambda d: summary_llm(d))

# (3)
docs.create_node_group(name='sentence',
transform=lambda b: ''.split(b),
transform=lambda b: b.split(''),
parent='block')

# (4)
Expand Down
4 changes: 2 additions & 2 deletions docs/zh/Best Practice/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ docs = Document()

# (1)
docs.create_node_group(name='block',
transform=lambda d: '\n'.split(d))
transform=lambda d: d.split('\n'))

# (2)
docs.create_node_group(name='doc-summary',
transform=lambda d: summary_llm(d))

# (3)
docs.create_node_group(name='sentence',
transform=lambda b: ''.split(b),
transform=lambda b: b.split(''),
parent='block')

# (4)
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def load_data(self, input_files: Optional[List[str]] = None, metadates: Optional
nodes.append(doc)
if not nodes:
LOG.warning(
f"No nodes load from path {self.input_files}, please check your data path."
f"No nodes load from path {input_files}, please check your data path."
)
return nodes
14 changes: 9 additions & 5 deletions lazyllm/tools/rag/milvus_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from collections import defaultdict
from typing import Dict, List, Optional, Union, Callable, Set
from lazyllm.thirdparty import pymilvus
from .doc_node import DocNode
Expand All @@ -12,6 +13,8 @@
from .data_type import DataType
from lazyllm.common import override, obj2str, str2obj

MILVUS_UPSERT_BATCH_SIZE = 500

class MilvusStore(StoreBase):
# we define these variables as members so that pymilvus is not imported until MilvusStore is instantiated.
def _def_constants(self) -> None:
Expand Down Expand Up @@ -156,13 +159,14 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla

@override
def update_nodes(self, nodes: List[DocNode]) -> None:
parallel_do_embedding(self._embed, [], nodes, self._group_embed_keys)
group_embed_dict = defaultdict(list)
for node in nodes:
embed_keys = self._group_embed_keys.get(node._group)
if embed_keys:
parallel_do_embedding(self._embed, embed_keys, [node])
data = self._serialize_node_partial(node)
self._client.upsert(collection_name=node._group, data=[data])

group_embed_dict[node._group].append(data)
for group_name, data in group_embed_dict.items():
for i in range(0, MILVUS_UPSERT_BATCH_SIZE, len(data)):
self._client.upsert(collection_name=group_name, data=data[i:i + MILVUS_UPSERT_BATCH_SIZE])
self._map_store.update_nodes(nodes)

@override
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def transform(self, node: DocNode, **kwargs) -> List[Union[str, DocNode]]:
You should not have any unnecessary output. Lets begin:
"""),
cn=dict(summary="""
zh=dict(summary="""
## 角色:文本摘要
你是一个文本摘要引擎,负责分析用户输入的文本,并根据请求任务提供简洁的摘要。
Expand Down
6 changes: 5 additions & 1 deletion lazyllm/tools/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,15 @@ def save_files_in_threads(

# returns a list of modified nodes
def parallel_do_embedding(embed: Dict[str, Callable], embed_keys: Optional[Union[List[str], Set[str]]],
nodes: List[DocNode]) -> List[DocNode]:
nodes: List[DocNode], group_embed_keys: Dict[str, List[str]] = None) -> List[DocNode]:
modified_nodes = []
with ThreadPoolExecutor(config["max_embedding_workers"]) as executor:
futures = []
for node in nodes:
if group_embed_keys:
embed_keys = group_embed_keys.get(node._group)
if not embed_keys:
continue
miss_keys = node.has_missing_embedding(embed_keys)
if not miss_keys:
continue
Expand Down
Loading