feat: fix dataset retrieve agent "LLM not supported" error (#656)

This commit is contained in:
John Wang 2023-07-27 15:45:52 +08:00 committed by GitHub
parent ae7c0380dc
commit ba3dc8cae0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 2 deletions

View File

@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
),
**kwargs: Any,
) -> BaseSingleActionAgent:
-llm.model_name = 'gpt-3.5-turbo'
return super().from_llm_and_tools(
llm=llm,
tools=tools,

View File

@ -31,6 +31,7 @@ class AgentConfiguration(BaseModel):
llm: BaseLanguageModel
tools: list[BaseTool]
summary_llm: BaseLanguageModel
dataset_llm: BaseLanguageModel
memory: Optional[BaseChatMemory] = None
callbacks: Callbacks = None
max_iterations: int = 6
@ -84,7 +85,7 @@ class AgentExecutor:
elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
-llm=self.configuration.llm,
+llm=self.configuration.dataset_llm,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True

View File

@ -32,6 +32,7 @@ class OrchestratorRuleParser:
self.tenant_id = tenant_id
self.app_model_config = app_model_config
self.agent_summary_model_name = "gpt-3.5-turbo-16k"
self.dataset_retrieve_model_name = "gpt-3.5-turbo"
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
@ -89,11 +90,20 @@ class OrchestratorRuleParser:
if len(tools) == 0:
return None
dataset_llm = LLMBuilder.to_llm(
tenant_id=self.tenant_id,
model_name=self.dataset_retrieve_model_name,
temperature=0,
max_tokens=500,
callbacks=[DifyStdOutCallbackHandler()]
)
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
llm=agent_llm,
tools=tools,
summary_llm=summary_llm,
dataset_llm=dataset_llm,
memory=memory,
callbacks=[chain_callback, agent_callback],
max_iterations=10,