chore: remove Langchain tools import (#3407)

Authored by Jyong on 2024-04-12 16:26:09 +08:00, committed by GitHub
parent c227f3d985
commit 0737e930cb
9 changed files with 98 additions and 73 deletions

View File

@@ -159,7 +159,7 @@ class BlobLoader(ABC):
     def yield_blobs(
         self,
     ) -> Iterable[Blob]:
-        """A lazy loader for raw data represented by LangChain's Blob object.
+        """A lazy loader for raw data represented by Blob object.

         Returns:
             A generator over blobs

View File

@@ -2,7 +2,6 @@ import threading
 from typing import Optional, cast

 from flask import Flask, current_app
-from langchain.tools import BaseTool

 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
@@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.rerank.rerank import RerankRunner
 from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
@@ -383,7 +383,7 @@ class DatasetRetrieval:
                                  return_resource: bool,
                                  invoke_from: InvokeFrom,
                                  hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[list[BaseTool]]:
+            -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param tenant_id: tenant id

View File

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import NamedTuple, Union
+
+
+@dataclass
+class ReactAction:
+    """A full description of an action for an ReactAction to execute."""
+
+    tool: str
+    """The name of the Tool to execute."""
+    tool_input: Union[str, dict]
+    """The input to pass in to the Tool."""
+    log: str
+    """Additional information to log about the action."""
+
+
+class ReactFinish(NamedTuple):
+    """The final return value of an ReactFinish."""
+
+    return_values: dict
+    """Dictionary of return values."""
+    log: str
+    """Additional information to log about the return value"""

View File

@@ -2,28 +2,24 @@ import json
 import re
 from typing import Union

-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
-from langchain.agents.structured_chat.output_parser import logger
-from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish


-class StructuredChatOutputParser(LCStructuredChatOutputParser):
-    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
+class StructuredChatOutputParser:
+    def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
         try:
             action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             if action_match is not None:
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 if isinstance(response, list):
-                    # gpt turbo frequently ignores the directive to emit a single action
-                    logger.warning("Got multiple action responses: %s", response)
                     response = response[0]
                 if response["action"] == "Final Answer":
-                    return AgentFinish({"output": response["action_input"]}, text)
+                    return ReactFinish({"output": response["action_input"]}, text)
                 else:
-                    return AgentAction(
+                    return ReactAction(
                         response["action"], response.get("action_input", {}), text
                     )
             else:
-                return AgentFinish({"output": text}, text)
+                return ReactFinish({"output": text}, text)
         except Exception as e:
-            raise OutputParserException(f"Could not parse LLM output: {text}")
+            raise ValueError(f"Could not parse LLM output: {text}")
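
A minimal usage sketch of the rewritten parser (not part of the diff; the sample LLM output is invented): it extracts the fenced JSON blob from the model text and returns either a ReactAction or a ReactFinish.

from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser

parser = StructuredChatOutputParser()

# Invented model output in the expected Thought/Action format.
sample = '''Thought: I need to look this up in the product dataset.
Action:
```
{"action": "dataset_products", "action_input": {"query": "warranty length"}}
```'''

decision = parser.parse(sample)
if isinstance(decision, ReactAction):
    print(decision.tool)        # dataset_products
    print(decision.tool_input)  # {'query': 'warranty length'}
else:
    # ReactFinish: the model answered directly with "Final Answer"
    print(decision.return_values["output"])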

View File

@@ -1,20 +1,21 @@
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Union

-from langchain import PromptTemplate
-from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
-from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
-from langchain.schema import AgentAction
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
+from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.llm.llm_node import LLMNode

+PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
+
+SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Thought:"""
+
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
 Valid "action" values: "Final Answer" or {tool_names}
@@ -86,7 +87,6 @@ class ReactMultiDatasetRouter:
                       tenant_id: str,
                       prefix: str = PREFIX,
                       suffix: str = SUFFIX,
-                      human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
                       format_instructions: str = FORMAT_INSTRUCTIONS,
                       ) -> Union[str, None]:
         if model_config.mode == "chat":
@@ -95,7 +95,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 prefix=prefix,
                 suffix=suffix,
-                human_message_template=human_message_template,
                 format_instructions=format_instructions,
             )
         else:
@@ -103,7 +102,6 @@ class ReactMultiDatasetRouter:
                 tools=tools,
                 prefix=prefix,
                 format_instructions=format_instructions,
-                input_variables=None
             )
         stop = ['Observation:']
         # handle invoke result
@@ -127,9 +125,9 @@ class ReactMultiDatasetRouter:
             tenant_id=tenant_id
         )
         output_parser = StructuredChatOutputParser()
-        agent_decision = output_parser.parse(result_text)
-        if isinstance(agent_decision, AgentAction):
-            return agent_decision.tool
+        react_decision = output_parser.parse(result_text)
+        if isinstance(react_decision, ReactAction):
+            return react_decision.tool
         return None

     def _invoke_llm(self, completion_param: dict,
@@ -139,7 +137,6 @@ class ReactMultiDatasetRouter:
                     ) -> tuple[str, LLMUsage]:
         """
        Invoke large language model
-        :param node_data: node data
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop
@@ -197,7 +194,6 @@ class ReactMultiDatasetRouter:
                                tools: Sequence[PromptMessageTool],
                                prefix: str = PREFIX,
                                suffix: str = SUFFIX,
-                               human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
                                format_instructions: str = FORMAT_INSTRUCTIONS,
                                ) -> list[ChatModelMessage]:
         tool_strings = []
@@ -227,16 +223,13 @@ class ReactMultiDatasetRouter:
                                      tools: Sequence[PromptMessageTool],
                                      prefix: str = PREFIX,
                                      format_instructions: str = FORMAT_INSTRUCTIONS,
-                                     input_variables: Optional[list[str]] = None,
-                                     ) -> PromptTemplate:
+                                     ) -> CompletionModelPromptTemplate:
         """Create prompt in the style of the zero shot agent.

         Args:
             tools: List of tools the agent will have access to, used to format the
                 prompt.
             prefix: String to put before the list of tools.
-            input_variables: List of input variables the final prompt will expect.

         Returns:
             A PromptTemplate with the template assembled from the pieces here.
         """
@@ -249,6 +242,4 @@ Thought: {agent_scratchpad}
         tool_names = ", ".join([tool.name for tool in tools])
         format_instructions = format_instructions.format(tool_names=tool_names)
         template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
-        if input_variables is None:
-            input_variables = ["input", "agent_scratchpad"]
-        return PromptTemplate(template=template, input_variables=input_variables)
+        return CompletionModelPromptTemplate(text=template)
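
A rough sketch of the completion-style prompt assembly after this change (not from the commit; the tool definitions are placeholders and the exact tool-string formatting inside the router is assumed): plain strings are joined and wrapped in CompletionModelPromptTemplate instead of a langchain PromptTemplate.

from core.model_runtime.entities.message_entities import PromptMessageTool
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate
from core.rag.retrieval.router.multi_dataset_react_route import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX

# Placeholder tools; in the router these come from the configured datasets.
tools = [
    PromptMessageTool(name="dataset_products", description="Product documentation.", parameters={}),
    PromptMessageTool(name="dataset_support", description="Historical support tickets.", parameters={}),
]

tool_strings = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)  # formatting assumed
tool_names = ", ".join(tool.name for tool in tools)
format_instructions = FORMAT_INSTRUCTIONS.format(tool_names=tool_names)

# Same join the router performs before wrapping the text.
template = "\n\n".join([PREFIX, tool_strings, format_instructions, SUFFIX])
prompt = CompletionModelPromptTemplate(text=template)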

View File

@@ -1,8 +1,6 @@
 import threading
-from typing import Optional

 from flask import Flask, current_app
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -10,6 +8,7 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rerank.rerank import RerankRunner
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
@@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
     query: str = Field(..., description="dataset multi retriever and rerank")


-class DatasetMultiRetrieverTool(BaseTool):
+class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying multi dataset."""
     name: str = "dataset_"
     args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
     description: str = "dataset multi retriever and rerank. "
-    tenant_id: str
     dataset_ids: list[str]
-    top_k: int = 2
-    score_threshold: Optional[float] = None
     reranking_provider_name: str
     reranking_model_name: str
-    return_resource: bool
-    retriever_from: str
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []

     @classmethod
     def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
@@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
         return str("\n".join(document_context_list))

-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()
-
     def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
                    hit_callbacks: list[DatasetIndexToolCallbackHandler]):
         with flask_app.app_context():
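
To make the scattered field removals above easier to read, here is a condensed view of what the class declares after this change (illustration only, collapsing the hunks above); every removed field now lives on DatasetRetrieverBaseTool, which the next file adds.

class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
    """Tool for querying multi dataset."""
    name: str = "dataset_"
    args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    dataset_ids: list[str]
    reranking_provider_name: str
    reranking_model_name: str
    # inherited from DatasetRetrieverBaseTool: tenant_id, top_k, score_threshold,
    # hit_callbacks, return_resource, retriever_from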

View File

@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+
+
+class DatasetRetrieverBaseTool(BaseModel, ABC):
+    """Tool for querying a Dataset."""
+    name: str = "dataset"
+    description: str = "use this to retrieve a dataset. "
+    tenant_id: str
+    top_k: int = 2
+    score_threshold: Optional[float] = None
+    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
+    return_resource: bool
+    retriever_from: str
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @abstractmethod
+    def _run(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Use the tool.
+
+        Add run_manager: Optional[CallbackManagerForToolRun] = None
+        to child implementations to enable tracing,
+        """

View File

@@ -1,10 +1,8 @@
-from typing import Optional
-
-from langchain.tools import BaseTool
 from pydantic import BaseModel, Field

-from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
@@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
     query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


-class DatasetRetrieverTool(BaseTool):
+class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     """Tool for querying a Dataset."""
     name: str = "dataset"
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     description: str = "use this to retrieve a dataset. "
-    tenant_id: str
     dataset_id: str
-    top_k: int = 2
-    score_threshold: Optional[float] = None
-    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
-    return_resource: bool
-    retriever_from: str

     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -153,7 +145,4 @@ class DatasetRetrieverTool(BaseTool):
         for hit_callback in self.hit_callbacks:
             hit_callback.return_retriever_resource_info(context_list)

         return str("\n".join(document_context_list))
-
-    async def _arun(self, tool_input: str) -> str:
-        raise NotImplementedError()

View File

@@ -1,7 +1,5 @@
 from typing import Any

-from langchain.tools import BaseTool
-
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolProviderType,
 )
+from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.tool import Tool


 class DatasetRetrieverTool(Tool):
-    langchain_tool: BaseTool
+    retrival_tool: DatasetRetrieverBaseTool

     @staticmethod
     def get_dataset_tools(tenant_id: str,
@@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
         # Agent only support SINGLE mode
         original_retriever_mode = retrieve_config.retrieve_strategy
         retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
-        langchain_tools = feature.to_dataset_retriever_tool(
+        retrival_tools = feature.to_dataset_retriever_tool(
             tenant_id=tenant_id,
             dataset_ids=dataset_ids,
             retrieve_config=retrieve_config,
@@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode

-        # convert langchain tools to Tools
+        # convert retrival tools to Tools
         tools = []
-        for langchain_tool in langchain_tools:
+        for retrival_tool in retrival_tools:
             tool = DatasetRetrieverTool(
-                langchain_tool=langchain_tool,
-                identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
+                retrival_tool=retrival_tool,
+                identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
                 parameters=[],
                 is_team_authorization=True,
                 description=ToolDescription(
                     human=I18nObject(en_US='', zh_Hans=''),
-                    llm=langchain_tool.description),
+                    llm=retrival_tool.description),
                 runtime=DatasetRetrieverTool.Runtime()
             )
@@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
             return self.create_text_message(text='please input query')

         # invoke dataset retriever tool
-        result = self.langchain_tool._run(query=query)
+        result = self.retrival_tool._run(query=query)

         return self.create_text_message(text=result)