From 5c81f6c75c5d3eed02ca54c022b015073699ab69 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 12 Apr 2024 16:26:09 +0800 Subject: [PATCH] chore: remove Langchain tools import (#3407) --- api/core/rag/extractor/blod/blod.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 4 +-- .../retrieval/output_parser/react_output.py | 25 +++++++++++++ .../output_parser/structured_chat.py | 18 ++++------ .../router/multi_dataset_react_route.py | 35 +++++++------------ .../dataset_multi_retriever_tool.py | 15 ++------ .../dataset_retriever_base_tool.py | 34 ++++++++++++++++++ .../dataset_retriever_tool.py | 19 +++------- api/core/tools/tool/dataset_retriever_tool.py | 19 +++++----- 9 files changed, 98 insertions(+), 73 deletions(-) create mode 100644 api/core/rag/retrieval/output_parser/react_output.py create mode 100644 api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blod/blod.py index 368946b5e41d1e..8d423e1b3f6236 100644 --- a/api/core/rag/extractor/blod/blod.py +++ b/api/core/rag/extractor/blod/blod.py @@ -159,7 +159,7 @@ class BlobLoader(ABC): def yield_blobs( self, ) -> Iterable[Blob]: - """A lazy loader for raw data represented by LangChain's Blob object. + """A lazy loader for raw data represented by Blob object. Returns: A generator over blobs diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 02bd92c1451eb5..155b8be06c0bb1 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -2,7 +2,6 @@ from typing import Optional, cast from flask import Flask, current_app -from langchain.tools import BaseTool from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity @@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.rerank.rerank import RerankRunner from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, DocumentSegment @@ -383,7 +383,7 @@ def to_dataset_retriever_tool(self, tenant_id: str, return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[BaseTool]]: + -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py new file mode 100644 index 00000000000000..9a14d417164e62 --- /dev/null +++ b/api/core/rag/retrieval/output_parser/react_output.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import NamedTuple, Union + + +@dataclass +class ReactAction: + """A full description of an action for an ReactAction to execute.""" + + tool: str + """The name of the Tool to execute.""" + tool_input: Union[str, dict] + """The input to pass in to the Tool.""" + log: str + """Additional information to log about the action.""" + + +class ReactFinish(NamedTuple): + """The final return value of an ReactFinish.""" + + return_values: dict + """Dictionary of return values.""" + log: str + """Additional information to log about the return value""" diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index c2d748d8f6e31b..60770bd4c6e06a 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -2,28 +2,24 @@ import re from typing import Union -from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser -from langchain.agents.structured_chat.output_parser import logger -from langchain.schema import AgentAction, AgentFinish, OutputParserException +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -class StructuredChatOutputParser(LCStructuredChatOutputParser): - def parse(self, text: str) -> Union[AgentAction, AgentFinish]: +class StructuredChatOutputParser: + def parse(self, text: str) -> Union[ReactAction, ReactFinish]: try: action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL) if action_match is not None: response = json.loads(action_match.group(2).strip(), strict=False) if isinstance(response, list): - # gpt turbo frequently ignores the directive to emit a single action - logger.warning("Got multiple action responses: %s", response) response = response[0] if response["action"] == "Final Answer": - return AgentFinish({"output": response["action_input"]}, text) + return ReactFinish({"output": response["action_input"]}, text) else: - return AgentAction( + return ReactAction( response["action"], response.get("action_input", {}), text ) else: - return AgentFinish({"output": text}, text) + return ReactFinish({"output": text}, text) except Exception as e: - raise OutputParserException(f"Could not parse LLM output: {text}") + raise ValueError(f"Could not parse LLM output: {text}") diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 0ec01047b32745..5de2a66e2dacb6 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,20 +1,21 @@ from collections.abc import Generator, Sequence -from typing import Optional, Union - -from langchain import PromptTemplate -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX -from langchain.schema import AgentAction +from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser from core.workflow.nodes.llm.llm_node import LLMNode +PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" + +SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. +Thought:""" + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. Valid "action" values: "Final Answer" or {tool_names} @@ -86,7 +87,6 @@ def _react_invoke( tenant_id: str, prefix: str = PREFIX, suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": @@ -95,7 +95,6 @@ def _react_invoke( tools=tools, prefix=prefix, suffix=suffix, - human_message_template=human_message_template, format_instructions=format_instructions, ) else: @@ -103,7 +102,6 @@ def _react_invoke( tools=tools, prefix=prefix, format_instructions=format_instructions, - input_variables=None ) stop = ['Observation:'] # handle invoke result @@ -127,9 +125,9 @@ def _react_invoke( tenant_id=tenant_id ) output_parser = StructuredChatOutputParser() - agent_decision = output_parser.parse(result_text) - if isinstance(agent_decision, AgentAction): - return agent_decision.tool + react_decision = output_parser.parse(result_text) + if isinstance(react_decision, ReactAction): + return react_decision.tool return None def _invoke_llm(self, completion_param: dict, @@ -139,7 +137,6 @@ def _invoke_llm(self, completion_param: dict, ) -> tuple[str, LLMUsage]: """ Invoke large language model - :param node_data: node data :param model_instance: model instance :param prompt_messages: prompt messages :param stop: stop @@ -197,7 +194,6 @@ def create_chat_prompt( tools: Sequence[PromptMessageTool], prefix: str = PREFIX, suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] @@ -227,16 +223,13 @@ def create_completion_prompt( tools: Sequence[PromptMessageTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: + ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. Args: tools: List of tools the agent will have access to, used to format the prompt. prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - Returns: A PromptTemplate with the template assembled from the pieces here. """ @@ -249,6 +242,4 @@ def create_completion_prompt( tool_names = ", ".join([tool.name for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) + return CompletionModelPromptTemplate(text=template) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index e977772fd18335..6e11427d58ac2d 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,8 +1,6 @@ import threading -from typing import Optional from flask import Flask, current_app -from langchain.tools import BaseTool from pydantic import BaseModel, Field from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -10,6 +8,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService from core.rerank.rerank import RerankRunner +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel): query: str = Field(..., description="dataset multi retriever and rerank") -class DatasetMultiRetrieverTool(BaseTool): +class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying multi dataset.""" name: str = "dataset_" args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput description: str = "dataset multi retriever and rerank. " - tenant_id: str dataset_ids: list[str] - top_k: int = 2 - score_threshold: Optional[float] = None reranking_provider_name: str reranking_model_name: str - return_resource: bool - retriever_from: str - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): @@ -149,9 +143,6 @@ def _run(self, query: str) -> str: return str("\n".join(document_context_list)) - async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, hit_callbacks: list[DatasetIndexToolCallbackHandler]): with flask_app.app_context(): diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py new file mode 100644 index 00000000000000..1f8478f5541acf --- /dev/null +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -0,0 +1,34 @@ +from abc import abstractmethod +from typing import Any, Optional + +from msal_extensions.persistence import ABC +from pydantic import BaseModel + +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler + + +class DatasetRetrieverBaseTool(BaseModel, ABC): + """Tool for querying a Dataset.""" + name: str = "dataset" + description: str = "use this to retrieve a dataset. " + tenant_id: str + top_k: int = 2 + score_threshold: Optional[float] = None + hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] + return_resource: bool + retriever_from: str + + class Config: + arbitrary_types_allowed = True + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. + + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index b99e392de017a3..552174e0bad82d 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,10 +1,8 @@ -from typing import Optional -from langchain.tools import BaseTool from pydantic import BaseModel, Field -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel): query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") -class DatasetRetrieverTool(BaseTool): +class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " - - tenant_id: str dataset_id: str - top_k: int = 2 - score_threshold: Optional[float] = None - hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] - return_resource: bool - retriever_from: str + @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): @@ -153,7 +145,4 @@ def _run(self, query: str) -> str: for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) - return str("\n".join(document_context_list)) - - async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() + return str("\n".join(document_context_list)) \ No newline at end of file diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 421f8a0483379c..e52981b2d14591 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,7 +1,5 @@ from typing import Any -from langchain.tools import BaseTool - from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -14,11 +12,12 @@ ToolParameter, ToolProviderType, ) +from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.tool.tool import Tool class DatasetRetrieverTool(Tool): - langchain_tool: BaseTool + retrival_tool: DatasetRetrieverBaseTool @staticmethod def get_dataset_tools(tenant_id: str, @@ -43,7 +42,7 @@ def get_dataset_tools(tenant_id: str, # Agent only support SINGLE mode original_retriever_mode = retrieve_config.retrieve_strategy retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - langchain_tools = feature.to_dataset_retriever_tool( + retrival_tools = feature.to_dataset_retriever_tool( tenant_id=tenant_id, dataset_ids=dataset_ids, retrieve_config=retrieve_config, @@ -54,17 +53,17 @@ def get_dataset_tools(tenant_id: str, # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - # convert langchain tools to Tools + # convert retrival tools to Tools tools = [] - for langchain_tool in langchain_tools: + for retrival_tool in retrival_tools: tool = DatasetRetrieverTool( - langchain_tool=langchain_tool, - identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')), + retrival_tool=retrival_tool, + identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')), parameters=[], is_team_authorization=True, description=ToolDescription( human=I18nObject(en_US='', zh_Hans=''), - llm=langchain_tool.description), + llm=retrival_tool.description), runtime=DatasetRetrieverTool.Runtime() ) @@ -96,7 +95,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text='please input query') # invoke dataset retriever tool - result = self.langchain_tool._run(query=query) + result = self.retrival_tool._run(query=query) return self.create_text_message(text=result)