mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 01:39:04 +08:00
feat: optimize completion model agent (#1364)
This commit is contained in:
parent
16d80ebab3
commit
07285e5f8b
@ -76,7 +76,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@ -12,6 +12,7 @@ from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@ -92,6 +93,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
@ -107,6 +112,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
elif isinstance(tool_inputs, str):
|
||||
agent_decision.tool_input = kwargs['input']
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
@ -143,6 +150,61 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def create_completion_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
agent_scratchpad += action.log
|
||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
@ -160,15 +222,23 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
else:
|
||||
prompt = cls.create_completion_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@ -15,6 +15,7 @@ from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
@ -184,6 +185,61 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def create_completion_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
agent_scratchpad += action.log
|
||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
@ -201,15 +257,23 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
else:
|
||||
prompt = cls.create_completion_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
|
Loading…
x
Reference in New Issue
Block a user