feat: optimize completion model agent (#1364)
commit 07285e5f8b (parent 16d80ebab3)
@@ -76,7 +76,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
         if isinstance(agent_decision, AgentAction):
             tool_inputs = agent_decision.tool_input
-            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+            if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                 tool_inputs['query'] = kwargs['input']
                 agent_decision.tool_input = tool_inputs
         else:
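
With the guard above, the router only overwrites the tool's query with the raw user input on single-turn calls; when chat history is present in kwargs, the query the model composed from the conversation is kept. A minimal sketch of the resulting behavior, with made-up inputs (nothing here is part of the diff):

# Illustration only; inputs are hypothetical.
kwargs = {'input': 'and what about its population?',
          'chat_history': [('human', 'Tell me about Paris')]}
tool_inputs = {'query': 'population of Paris'}  # query the model composed

if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
    tool_inputs['query'] = kwargs['input']  # only taken on single-turn calls

print(tool_inputs)  # {'query': 'population of Paris'} -- history present, query kept
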
@@ -1,7 +1,7 @@
 import re
 from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -12,6 +12,7 @@ from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -92,6 +93,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 
+        if intermediate_steps:
+            _, observation = intermediate_steps[-1]
+            return AgentFinish(return_values={"output": observation}, log=observation)
+
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
 
         try:
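
The block added above short-circuits a second LLM round-trip: once a dataset tool has produced an observation, that observation is returned directly as the final answer. A self-contained sketch of the data flow using langchain's schema types (the tool name and observation are invented for illustration):

from langchain.schema import AgentAction, AgentFinish

# Hypothetical single step: the router already queried a dataset tool.
intermediate_steps = [
    (AgentAction(tool='dataset-sales', tool_input={'query': 'pricing'}, log='...'),
     'Plan A costs $10/month.'),  # observation returned by the tool
]

_, observation = intermediate_steps[-1]
finish = AgentFinish(return_values={'output': observation}, log=observation)
print(finish.return_values['output'])  # Plan A costs $10/month.
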
@@ -107,6 +112,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                 tool_inputs['query'] = kwargs['input']
                 agent_decision.tool_input = tool_inputs
+            elif isinstance(tool_inputs, str):
+                agent_decision.tool_input = kwargs['input']
         else:
             agent_decision.return_values['output'] = ''
         return agent_decision
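
The new elif covers the case where the parsed action_input arrives as a bare string rather than a {'query': ...} dict, which a completion model's JSON blob can plausibly produce; the user's raw input is forwarded in that case too. A trivial illustration (values invented):

tool_input = 'model-written search phrase'        # str, not a dict
kwargs = {'input': 'What plans do you offer?'}

if isinstance(tool_input, dict) and 'query' in tool_input:
    tool_input['query'] = kwargs['input']
elif isinstance(tool_input, str):
    tool_input = kwargs['input']

print(tool_input)  # What plans do you offer?
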
@@ -143,6 +150,61 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Create prompt in the style of the zero shot agent.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            input_variables: List of input variables the final prompt will expect.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Question: {input}
+Thought: {agent_scratchpad}
+"""
+
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
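
To make create_completion_prompt concrete, here is the same joining logic run standalone; PREFIX and FORMAT_INSTRUCTIONS are abbreviated stand-ins and the tool is invented for illustration:

# Minimal sketch of the template assembly; strings are abbreviated stand-ins.
PREFIX = 'Respond to the human as helpfully and accurately as possible. You have access to the following tools:'
FORMAT_INSTRUCTIONS = 'Valid "action" values: "Final Answer" or {tool_names}'
suffix = 'Begin!\n\nQuestion: {input}\nThought: {agent_scratchpad}'

tools = [('dataset-sales', 'Searches the sales knowledge base.')]  # hypothetical
tool_strings = '\n'.join(f'{name}: {desc}' for name, desc in tools)
tool_names = ', '.join(name for name, _ in tools)
template = '\n\n'.join([PREFIX, tool_strings,
                        FORMAT_INSTRUCTIONS.format(tool_names=tool_names), suffix])
print(template)  # one flat prompt string with {input} and {agent_scratchpad} slots

The overridden _construct_scratchpad complements this: for a completion model the accumulated Thought/Observation text is returned verbatim so it can be spliced into the {agent_scratchpad} slot of that flat prompt, while chat models keep the original "This was your previous work" framing.
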
@@ -160,15 +222,23 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,
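
This gate is the heart of the commit: chat models keep the message-based prompt while completion models get the flat zero-shot template. A reduced sketch of the dispatch, where ModelMode stands in for core.model_providers.models.entity.model_params.ModelMode and only the CHAT/COMPLETION members are assumed:

from enum import Enum

class ModelMode(Enum):  # stand-in; members are assumptions
    CHAT = 'chat'
    COMPLETION = 'completion'

def build_prompt(mode: ModelMode) -> str:
    # Mirrors the branch in from_llm_and_tools: message prompt for chat
    # models, flat string prompt for completion models.
    if mode == ModelMode.CHAT:
        return 'ChatPromptTemplate(...)'
    return 'PromptTemplate(...)'

print(build_prompt(ModelMode.COMPLETION))  # PromptTemplate(...)
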
@@ -1,7 +1,7 @@
 import re
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -15,6 +15,7 @@ from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -184,6 +185,61 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Create prompt in the style of the zero shot agent.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            input_variables: List of input variables the final prompt will expect.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Question: {input}
+Thought: {agent_scratchpad}
+"""
+
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
@@ -201,15 +257,23 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,