mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 03:35:51 +08:00
chore: remove Langchain tools import (#3407)
This commit is contained in:
parent
c227f3d985
commit
0737e930cb
@ -159,7 +159,7 @@ class BlobLoader(ABC):
|
||||
def yield_blobs(
|
||||
self,
|
||||
) -> Iterable[Blob]:
|
||||
"""A lazy loader for raw data represented by LangChain's Blob object.
|
||||
"""A lazy loader for raw data represented by Blob object.
|
||||
|
||||
Returns:
|
||||
A generator over blobs
|
||||
|
@ -2,7 +2,6 @@ import threading
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
@ -19,6 +18,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
@ -383,7 +383,7 @@ class DatasetRetrieval:
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler) \
|
||||
-> Optional[list[BaseTool]]:
|
||||
-> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param tenant_id: tenant id
|
||||
|
25
api/core/rag/retrieval/output_parser/react_output.py
Normal file
25
api/core/rag/retrieval/output_parser/react_output.py
Normal file
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReactAction:
|
||||
"""A full description of an action for an ReactAction to execute."""
|
||||
|
||||
tool: str
|
||||
"""The name of the Tool to execute."""
|
||||
tool_input: Union[str, dict]
|
||||
"""The input to pass in to the Tool."""
|
||||
log: str
|
||||
"""Additional information to log about the action."""
|
||||
|
||||
|
||||
class ReactFinish(NamedTuple):
|
||||
"""The final return value of an ReactFinish."""
|
||||
|
||||
return_values: dict
|
||||
"""Dictionary of return values."""
|
||||
log: str
|
||||
"""Additional information to log about the return value"""
|
@ -2,28 +2,24 @@ import json
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser
|
||||
from langchain.agents.structured_chat.output_parser import logger
|
||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish
|
||||
|
||||
|
||||
class StructuredChatOutputParser(LCStructuredChatOutputParser):
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
class StructuredChatOutputParser:
|
||||
def parse(self, text: str) -> Union[ReactAction, ReactFinish]:
|
||||
try:
|
||||
action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
|
||||
if action_match is not None:
|
||||
response = json.loads(action_match.group(2).strip(), strict=False)
|
||||
if isinstance(response, list):
|
||||
# gpt turbo frequently ignores the directive to emit a single action
|
||||
logger.warning("Got multiple action responses: %s", response)
|
||||
response = response[0]
|
||||
if response["action"] == "Final Answer":
|
||||
return AgentFinish({"output": response["action_input"]}, text)
|
||||
return ReactFinish({"output": response["action_input"]}, text)
|
||||
else:
|
||||
return AgentAction(
|
||||
return ReactAction(
|
||||
response["action"], response.get("action_input", {}), text
|
||||
)
|
||||
else:
|
||||
return AgentFinish({"output": text}, text)
|
||||
return ReactFinish({"output": text}, text)
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Could not parse LLM output: {text}")
|
||||
raise ValueError(f"Could not parse LLM output: {text}")
|
||||
|
@ -1,20 +1,21 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
from langchain.schema import AgentAction
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Thought:"""
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
@ -86,7 +87,6 @@ class ReactMultiDatasetRouter:
|
||||
tenant_id: str,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> Union[str, None]:
|
||||
if model_config.mode == "chat":
|
||||
@ -95,7 +95,6 @@ class ReactMultiDatasetRouter:
|
||||
tools=tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
)
|
||||
else:
|
||||
@ -103,7 +102,6 @@ class ReactMultiDatasetRouter:
|
||||
tools=tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=None
|
||||
)
|
||||
stop = ['Observation:']
|
||||
# handle invoke result
|
||||
@ -127,9 +125,9 @@ class ReactMultiDatasetRouter:
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
output_parser = StructuredChatOutputParser()
|
||||
agent_decision = output_parser.parse(result_text)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
return agent_decision.tool
|
||||
react_decision = output_parser.parse(result_text)
|
||||
if isinstance(react_decision, ReactAction):
|
||||
return react_decision.tool
|
||||
return None
|
||||
|
||||
def _invoke_llm(self, completion_param: dict,
|
||||
@ -139,7 +137,6 @@ class ReactMultiDatasetRouter:
|
||||
) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data: node data
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
@ -197,7 +194,6 @@ class ReactMultiDatasetRouter:
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
) -> list[ChatModelMessage]:
|
||||
tool_strings = []
|
||||
@ -227,16 +223,13 @@ class ReactMultiDatasetRouter:
|
||||
tools: Sequence[PromptMessageTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
) -> CompletionModelPromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
@ -249,6 +242,4 @@ Thought: {agent_scratchpad}
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
return CompletionModelPromptTemplate(text=template)
|
||||
|
@ -1,8 +1,6 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
@ -10,6 +8,7 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
@ -29,20 +28,15 @@ class DatasetMultiRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||
|
||||
|
||||
class DatasetMultiRetrieverTool(BaseTool):
|
||||
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying multi dataset."""
|
||||
name: str = "dataset_"
|
||||
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||
description: str = "dataset multi retriever and rerank. "
|
||||
tenant_id: str
|
||||
dataset_ids: list[str]
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_provider_name: str
|
||||
reranking_model_name: str
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
|
||||
@ -149,9 +143,6 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list,
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler]):
|
||||
with flask_app.app_context():
|
||||
|
@ -0,0 +1,34 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
|
||||
class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
"""Tool for querying a Dataset."""
|
||||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool.
|
||||
|
||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
to child implementations to enable tracing,
|
||||
"""
|
@ -1,10 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
@ -24,19 +22,13 @@ class DatasetRetrieverToolInput(BaseModel):
|
||||
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
|
||||
|
||||
|
||||
class DatasetRetrieverTool(BaseTool):
|
||||
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
"""Tool for querying a Dataset."""
|
||||
name: str = "dataset"
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
@ -154,6 +146,3 @@ class DatasetRetrieverTool(BaseTool):
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
@ -1,7 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
@ -14,11 +12,12 @@ from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
langchain_tool: BaseTool
|
||||
retrival_tool: DatasetRetrieverBaseTool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(tenant_id: str,
|
||||
@ -43,7 +42,7 @@ class DatasetRetrieverTool(Tool):
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
langchain_tools = feature.to_dataset_retriever_tool(
|
||||
retrival_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
@ -54,17 +53,17 @@ class DatasetRetrieverTool(Tool):
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert langchain tools to Tools
|
||||
# convert retrival tools to Tools
|
||||
tools = []
|
||||
for langchain_tool in langchain_tools:
|
||||
for retrival_tool in retrival_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
langchain_tool=langchain_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
retrival_tool=retrival_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US='', zh_Hans=''),
|
||||
llm=langchain_tool.description),
|
||||
llm=retrival_tool.description),
|
||||
runtime=DatasetRetrieverTool.Runtime()
|
||||
)
|
||||
|
||||
@ -96,7 +95,7 @@ class DatasetRetrieverTool(Tool):
|
||||
return self.create_text_message(text='please input query')
|
||||
|
||||
# invoke dataset retriever tool
|
||||
result = self.langchain_tool._run(query=query)
|
||||
result = self.retrival_tool._run(query=query)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user