mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 05:09:03 +08:00
Refactor: CoT runner (#2157)
This commit is contained in:
parent
c8fb619d37
commit
48d5628fd4
@ -19,8 +19,6 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
|
|||||||
|
|
||||||
from models.model import Conversation, Message
|
from models.model import Conversation, Message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||||
def run(self, model_instance: ModelInstance,
|
def run(self, model_instance: ModelInstance,
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
@ -93,6 +91,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
prompt_messages_tools = []
|
prompt_messages_tools = []
|
||||||
|
|
||||||
message_file_ids = []
|
message_file_ids = []
|
||||||
|
|
||||||
agent_thought = self.create_agent_thought(
|
agent_thought = self.create_agent_thought(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
message='',
|
message='',
|
||||||
@ -100,6 +99,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
tool_input='',
|
tool_input='',
|
||||||
messages_ids=message_file_ids
|
messages_ids=message_file_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if iteration_step > 1:
|
||||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|
||||||
# update prompt messages
|
# update prompt messages
|
||||||
@ -138,6 +139,10 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
if llm_result.usage:
|
if llm_result.usage:
|
||||||
increse_usage(llm_usage, llm_result.usage)
|
increse_usage(llm_usage, llm_result.usage)
|
||||||
|
|
||||||
|
# publish agent thought if it's first iteration
|
||||||
|
if iteration_step == 1:
|
||||||
|
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|
||||||
self.save_agent_thought(agent_thought=agent_thought,
|
self.save_agent_thought(agent_thought=agent_thought,
|
||||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||||
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
tool_input=scratchpad.action.action_input if scratchpad.action else '',
|
||||||
@ -187,7 +192,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
tool_call_args = scratchpad.action.action_input
|
tool_call_args = scratchpad.action.action_input
|
||||||
tool_instance = tool_instances.get(tool_call_name)
|
tool_instance = tool_instances.get(tool_call_name)
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
logger.error(f"failed to find tool instance: {tool_call_name}")
|
|
||||||
answer = f"there is not a tool named {tool_call_name}"
|
answer = f"there is not a tool named {tool_call_name}"
|
||||||
self.save_agent_thought(agent_thought=agent_thought,
|
self.save_agent_thought(agent_thought=agent_thought,
|
||||||
tool_name='',
|
tool_name='',
|
||||||
@ -237,7 +241,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
|
|
||||||
if error_response:
|
if error_response:
|
||||||
observation = error_response
|
observation = error_response
|
||||||
logger.error(error_response)
|
|
||||||
else:
|
else:
|
||||||
observation = self._convert_tool_response_to_str(tool_response)
|
observation = self._convert_tool_response_to_str(tool_response)
|
||||||
|
|
||||||
@ -543,13 +546,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
|||||||
# add assistant message
|
# add assistant message
|
||||||
if len(agent_scratchpad) > 0:
|
if len(agent_scratchpad) > 0:
|
||||||
prompt_messages.append(AssistantPromptMessage(
|
prompt_messages.append(AssistantPromptMessage(
|
||||||
content=(agent_scratchpad[-1].thought or '') + "\n" + (agent_scratchpad[-1].observation or '')
|
content=(agent_scratchpad[-1].thought or '')
|
||||||
))
|
))
|
||||||
|
|
||||||
# add user message
|
# add user message
|
||||||
if len(agent_scratchpad) > 0:
|
if len(agent_scratchpad) > 0:
|
||||||
prompt_messages.append(UserPromptMessage(
|
prompt_messages.append(UserPromptMessage(
|
||||||
content=input,
|
content=(agent_scratchpad[-1].observation or ''),
|
||||||
))
|
))
|
||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
@ -2,13 +2,13 @@ identity:
|
|||||||
author: Dify
|
author: Dify
|
||||||
name: youtube
|
name: youtube
|
||||||
label:
|
label:
|
||||||
en_US: Youtube
|
en_US: YouTube
|
||||||
zh_Hans: Youtube
|
zh_Hans: YouTube
|
||||||
pt_BR: Youtube
|
pt_BR: YouTube
|
||||||
description:
|
description:
|
||||||
en_US: Youtube
|
en_US: YouTube
|
||||||
zh_Hans: Youtube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。
|
zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。
|
||||||
pt_BR: Youtube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
|
pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos.
|
||||||
icon: icon.png
|
icon: icon.png
|
||||||
credentials_for_provider:
|
credentials_for_provider:
|
||||||
google_api_key:
|
google_api_key:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user