Skip to content

Commit

Permalink
chore: remove Langchain tools import (langgenius#3407)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Apr 12, 2024
1 parent e1f0abe commit 5c81f6c
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 73 deletions.
2 changes: 1 addition & 1 deletion api/core/rag/extractor/blod/blod.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class BlobLoader(ABC):
def yield_blobs(
self,
) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
"""A lazy loader for raw data represented by Blob object.
Returns:
A generator over blobs
Expand Down
4 changes: 2 additions & 2 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional, cast

from flask import Flask, current_app
from langchain.tools import BaseTool

from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
Expand All @@ -19,6 +18,7 @@
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
Expand Down Expand Up @@ -383,7 +383,7 @@ def to_dataset_retriever_tool(self, tenant_id: str,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[BaseTool]]:
-> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
Expand Down
25 changes: 25 additions & 0 deletions api/core/rag/retrieval/output_parser/react_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import NamedTuple, Union


@dataclass
class ReactAction:
    """A full description of a single action to execute.

    Attributes:
        tool: The name of the Tool to execute.
        tool_input: The input to pass in to the Tool, either a plain string
            or a dict.
        log: Additional information to log about the action.
    """

    tool: str
    tool_input: Union[str, dict]
    log: str


class ReactFinish(NamedTuple):
    """The final return value once no further action is to be executed.

    Fields:
        return_values: Dictionary of return values.
        log: Additional information to log about the return value.
    """

    return_values: dict
    log: str
18 changes: 7 additions & 11 deletions api/core/rag/retrieval/output_parser/structured_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,24 @@
import re
from typing import Union

from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
from langchain.agents.structured_chat.output_parser import logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish


class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
class StructuredChatOutputParser:
def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
try:
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
# gpt turbo frequently ignores the directive to emit a single action
logger.warning("Got multiple action responses: %s", response)
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
return ReactFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
return ReactAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
return ReactFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}")
raise ValueError(f"Could not parse LLM output: {text}")
35 changes: 13 additions & 22 deletions api/core/rag/retrieval/router/multi_dataset_react_route.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from collections.abc import Generator, Sequence
from typing import Optional, Union

from langchain import PromptTemplate
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.schema import AgentAction
from typing import Union

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm.llm_node import LLMNode

PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""

SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Thought:"""

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Expand Down Expand Up @@ -86,7 +87,6 @@ def _react_invoke(
tenant_id: str,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
if model_config.mode == "chat":
Expand All @@ -95,15 +95,13 @@ def _react_invoke(
tools=tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
)
else:
prompt = self.create_completion_prompt(
tools=tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=None
)
stop = ['Observation:']
# handle invoke result
Expand All @@ -127,9 +125,9 @@ def _react_invoke(
tenant_id=tenant_id
)
output_parser = StructuredChatOutputParser()
agent_decision = output_parser.parse(result_text)
if isinstance(agent_decision, AgentAction):
return agent_decision.tool
react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction):
return react_decision.tool
return None

def _invoke_llm(self, completion_param: dict,
Expand All @@ -139,7 +137,6 @@ def _invoke_llm(self, completion_param: dict,
) -> tuple[str, LLMUsage]:
"""
Invoke large language model
:param node_data: node data
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
Expand Down Expand Up @@ -197,7 +194,6 @@ def create_chat_prompt(
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> list[ChatModelMessage]:
tool_strings = []
Expand Down Expand Up @@ -227,16 +223,13 @@ def create_completion_prompt(
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
) -> CompletionModelPromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A CompletionModelPromptTemplate with the template assembled from the pieces here.
"""
Expand All @@ -249,6 +242,4 @@ def create_completion_prompt(
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
return CompletionModelPromptTemplate(text=template)
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import threading
from typing import Optional

from flask import Flask, current_app
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

Expand All @@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")


class DatasetMultiRetrieverTool(BaseTool):
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
tenant_id: str
dataset_ids: list[str]
top_k: int = 2
score_threshold: Optional[float] = None
reranking_provider_name: str
reranking_model_name: str
return_resource: bool
retriever_from: str
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []


@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
Expand Down Expand Up @@ -149,9 +143,6 @@ def _run(self, query: str) -> str:

return str("\n".join(document_context_list))

async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()

def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler]):
with flask_app.app_context():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# NOTE: ``ABC`` must come from the standard-library ``abc`` module. It was
# previously imported from ``msal_extensions.persistence``, an unrelated
# third-party package, which made this module fail to import whenever that
# package was not installed.
from abc import ABC, abstractmethod
from typing import Any, Optional

from pydantic import BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler


class DatasetRetrieverBaseTool(BaseModel, ABC):
    """Abstract base for tools that query a Dataset.

    Holds the configuration fields shared by the dataset retriever tools;
    concrete subclasses must implement :meth:`_run`.
    """
    name: str = "dataset"
    description: str = "use this to retrieve a dataset. "
    tenant_id: str
    top_k: int = 2
    score_threshold: Optional[float] = None
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str

    class Config:
        # Lets pydantic accept field types it cannot validate itself
        # (presumably needed for DatasetIndexToolCallbackHandler -- confirm).
        arbitrary_types_allowed = True

    @abstractmethod
    def _run(
            self,
            *args: Any,
            **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Subclasses implement the actual retrieval here.

        Args:
            *args: Positional arguments defined by the concrete tool.
            **kwargs: Keyword arguments defined by the concrete tool.

        Returns:
            The result of running the tool; its type is up to the subclass.
        """
19 changes: 4 additions & 15 deletions api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Optional

from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

Expand All @@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(BaseTool):
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "

tenant_id: str
dataset_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str


@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
Expand Down Expand Up @@ -153,7 +145,4 @@ def _run(self, query: str) -> str:
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)

return str("\n".join(document_context_list))

async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
return str("\n".join(document_context_list))
19 changes: 9 additions & 10 deletions api/core/tools/tool/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any

from langchain.tools import BaseTool

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
Expand All @@ -14,11 +12,12 @@
ToolParameter,
ToolProviderType,
)
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.tool import Tool


class DatasetRetrieverTool(Tool):
langchain_tool: BaseTool
retrival_tool: DatasetRetrieverBaseTool

@staticmethod
def get_dataset_tools(tenant_id: str,
Expand All @@ -43,7 +42,7 @@ def get_dataset_tools(tenant_id: str,
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
langchain_tools = feature.to_dataset_retriever_tool(
retrival_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
Expand All @@ -54,17 +53,17 @@ def get_dataset_tools(tenant_id: str,
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode

# convert langchain tools to Tools
# convert retrieval tools to Tools
tools = []
for langchain_tool in langchain_tools:
for retrival_tool in retrival_tools:
tool = DatasetRetrieverTool(
langchain_tool=langchain_tool,
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
retrival_tool=retrival_tool,
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
human=I18nObject(en_US='', zh_Hans=''),
llm=langchain_tool.description),
llm=retrival_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)

Expand Down Expand Up @@ -96,7 +95,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
return self.create_text_message(text='please input query')

# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
result = self.retrival_tool._run(query=query)

return self.create_text_message(text=result)

Expand Down

0 comments on commit 5c81f6c

Please sign in to comment.