chore: remove Langchain tools import (#3407)

This commit is contained in:
Jyong 2024-04-12 16:26:09 +08:00 committed by GitHub
parent c227f3d985
commit 0737e930cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 98 additions and 73 deletions

View File

@ -159,7 +159,7 @@ class BlobLoader(ABC):
def yield_blobs(
self,
) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
"""A lazy loader for raw data represented by Blob object.
Returns:
A generator over blobs

View File

@ -2,7 +2,6 @@ import threading
from typing import Optional, cast
from flask import Flask, current_app
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
@ -383,7 +383,7 @@ class DatasetRetrieval:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[BaseTool]]:
-> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id

View File

@ -0,0 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import NamedTuple, Union
@dataclass
class ReactAction:
"""A full description of an action for an ReactAction to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
class ReactFinish(NamedTuple):
"""The final return value of an ReactFinish."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""

View File

@ -2,28 +2,24 @@ import json
import re
from typing import Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
from langchain.agents.structured_chat.output_parser import logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
class StructuredChatOutputParser(LCStructuredChatOutputParser):
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
class StructuredChatOutputParser:
def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
try:
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
if action_match is not None:
response = json.loads(action_match.group(2).strip(), strict=False)
if isinstance(response, list):
# gpt turbo frequently ignores the directive to emit a single action
logger.warning("Got multiple action responses: %s", response)
response = response[0]
if response["action"] == "Final Answer":
return AgentFinish({"output": response["action_input"]}, text)
return ReactFinish({"output": response["action_input"]}, text)
else:
return AgentAction(
return ReactAction(
response["action"], response.get("action_input", {}), text
)
else:
return AgentFinish({"output": text}, text)
return ReactFinish({"output": text}, text)
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}")
raise ValueError(f"Could not parse LLM output: {text}")

View File

@ -1,20 +1,21 @@
from collections.abc import Generator, Sequence
from typing import Optional, Union
from langchain import PromptTemplate
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.schema import AgentAction
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm.llm_node import LLMNode
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Thought:"""
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
@ -86,7 +87,6 @@ class ReactMultiDatasetRouter:
tenant_id: str,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
if model_config.mode == "chat":
@ -95,7 +95,6 @@ class ReactMultiDatasetRouter:
tools=tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
)
else:
@ -103,7 +102,6 @@ class ReactMultiDatasetRouter:
tools=tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=None
)
stop = ['Observation:']
# handle invoke result
@ -127,9 +125,9 @@ class ReactMultiDatasetRouter:
tenant_id=tenant_id
)
output_parser = StructuredChatOutputParser()
agent_decision = output_parser.parse(result_text)
if isinstance(agent_decision, AgentAction):
return agent_decision.tool
react_decision = output_parser.parse(result_text)
if isinstance(react_decision, ReactAction):
return react_decision.tool
return None
def _invoke_llm(self, completion_param: dict,
@ -139,7 +137,6 @@ class ReactMultiDatasetRouter:
) -> tuple[str, LLMUsage]:
"""
Invoke large language model
:param node_data: node data
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
@ -197,7 +194,6 @@ class ReactMultiDatasetRouter:
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> list[ChatModelMessage]:
tool_strings = []
@ -227,16 +223,13 @@ class ReactMultiDatasetRouter:
tools: Sequence[PromptMessageTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
) -> CompletionModelPromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
@ -249,6 +242,4 @@ Thought: {agent_scratchpad}
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
return CompletionModelPromptTemplate(text=template)

View File

@ -1,8 +1,6 @@
import threading
from typing import Optional
from flask import Flask, current_app
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@ -10,6 +8,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(BaseTool):
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
tenant_id: str
dataset_ids: list[str]
top_k: int = 2
score_threshold: Optional[float] = None
reranking_provider_name: str
reranking_model_name: str
return_resource: bool
retriever_from: str
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler]):
with flask_app.app_context():

View File

@ -0,0 +1,34 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from pydantic import BaseModel
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
class Config:
arbitrary_types_allowed = True
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""

View File

@ -1,10 +1,8 @@
from typing import Optional
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
tenant_id: str
dataset_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@ -154,6 +146,3 @@ class DatasetRetrieverTool(BaseTool):
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()

View File

@ -1,7 +1,5 @@
from typing import Any
from langchain.tools import BaseTool
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.tool import Tool
class DatasetRetrieverTool(Tool):
langchain_tool: BaseTool
retrival_tool: DatasetRetrieverBaseTool
@staticmethod
def get_dataset_tools(tenant_id: str,
@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
langchain_tools = feature.to_dataset_retriever_tool(
retrival_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools
# convert retrival tools to Tools
tools = []
for langchain_tool in langchain_tools:
for retrival_tool in retrival_tools:
tool = DatasetRetrieverTool(
langchain_tool=langchain_tool,
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
retrival_tool=retrival_tool,
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
human=I18nObject(en_US='', zh_Hans=''),
llm=langchain_tool.description),
llm=retrival_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
return self.create_text_message(text='please input query')
# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
result = self.retrival_tool._run(query=query)
return self.create_text_message(text=result)