From ba3dc8cae0ecba92bd15139346c5814f2fbd6edb Mon Sep 17 00:00:00 2001 From: John Wang Date: Thu, 27 Jul 2023 15:45:52 +0800 Subject: [PATCH] feat: fix dataset retrieve agent llm not support error (#656) --- api/core/agent/agent/multi_dataset_router_agent.py | 1 - api/core/agent/agent_executor.py | 3 ++- api/core/orchestrator_rule_parser.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index 9f148159d5..34dacaee3d 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ), **kwargs: Any, ) -> BaseSingleActionAgent: - llm.model_name = 'gpt-3.5-turbo' return super().from_llm_and_tools( llm=llm, tools=tools, diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index dbf4db0386..da36533fd2 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -31,6 +31,7 @@ class AgentConfiguration(BaseModel): llm: BaseLanguageModel tools: list[BaseTool] summary_llm: BaseLanguageModel + dataset_llm: BaseLanguageModel memory: Optional[BaseChatMemory] = None callbacks: Callbacks = None max_iterations: int = 6 @@ -84,7 +85,7 @@ class AgentExecutor: elif self.configuration.strategy == PlanningStrategy.ROUTER: self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] agent = MultiDatasetRouterAgent.from_llm_and_tools( - llm=self.configuration.llm, + llm=self.configuration.dataset_llm, tools=self.configuration.tools, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, verbose=True diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 8a57ee2f3c..971f7ffed5 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -32,6 +32,7 @@ class OrchestratorRuleParser: self.tenant_id = tenant_id self.app_model_config = app_model_config self.agent_summary_model_name = "gpt-3.5-turbo-16k" + self.dataset_retrieve_model_name = "gpt-3.5-turbo" def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ @@ -89,11 +90,20 @@ class OrchestratorRuleParser: if len(tools) == 0: return None + dataset_llm = LLMBuilder.to_llm( + tenant_id=self.tenant_id, + model_name=self.dataset_retrieve_model_name, + temperature=0, + max_tokens=500, + callbacks=[DifyStdOutCallbackHandler()] + ) + agent_configuration = AgentConfiguration( strategy=planning_strategy, llm=agent_llm, tools=tools, summary_llm=summary_llm, + dataset_llm=dataset_llm, memory=memory, callbacks=[chain_callback, agent_callback], max_iterations=10,