From 1d9cc5ca056d23cbaa1511b9e929b016f03d4d60 Mon Sep 17 00:00:00 2001
From: takatost
Date: Fri, 18 Aug 2023 16:20:42 +0800
Subject: [PATCH] fix: universal chat when default model invalid (#905)

---
 .../openai_function_call_summarize_mixin.py |  2 +-
 api/core/agent/agent/structured_chat.py     |  2 +-
 api/core/agent/agent_executor.py            |  2 +-
 api/core/model_providers/model_factory.py   |  6 ++-
 api/core/orchestrator_rule_parser.py        | 44 +++++++++++--------
 api/core/tool/current_datetime_tool.py      | 25 +++++
 api/core/tool/web_reader_tool.py            |  4 +-
 api/libs/helper.py                          |  6 ---
 8 files changed, 59 insertions(+), 32 deletions(-)
 create mode 100644 api/core/tool/current_datetime_tool.py

diff --git a/api/core/agent/agent/openai_function_call_summarize_mixin.py b/api/core/agent/agent/openai_function_call_summarize_mixin.py
index 6dcda1200a..f436346e24 100644
--- a/api/core/agent/agent/openai_function_call_summarize_mixin.py
+++ b/api/core/agent/agent/openai_function_call_summarize_mixin.py
@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM
 
     class Config:
diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py
index c16214c632..9d4f4d608c 100644
--- a/api/core/agent/agent/structured_chat.py
+++ b/api/core/agent/agent/structured_chat.py
@@ -52,7 +52,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM
 
     class Config:
diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py
index ada5ab49a8..17d8ecf6bd 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/agent/agent_executor.py
@@ -32,7 +32,7 @@ class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
diff --git a/api/core/model_providers/model_factory.py b/api/core/model_providers/model_factory.py
index b76a640256..ae5951457d 100644
--- a/api/core/model_providers/model_factory.py
+++ b/api/core/model_providers/model_factory.py
@@ -46,7 +46,8 @@
                                   model_name: Optional[str] = None,
                                   model_kwargs: Optional[ModelKwargs] = None,
                                   streaming: bool = False,
-                                  callbacks: Callbacks = None) -> Optional[BaseLLM]:
+                                  callbacks: Callbacks = None,
+                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
         """
         get text generation model.
 
@@ -56,6 +57,7 @@ class ModelFactory:
         :param model_kwargs:
         :param streaming:
         :param callbacks:
+        :param deduct_quota:
         :return:
         """
         is_default_model = False
@@ -95,7 +97,7 @@ class ModelFactory:
             else:
                 raise e
 
-        if is_default_model:
+        if is_default_model or not deduct_quota:
             model_instance.deduct_quota = False
 
         return model_instance
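Note on the model_factory.py change above: get_text_generation_model() gains a deduct_quota flag, and the new "if is_default_model or not deduct_quota" check clears model_instance.deduct_quota, so a secondary model fetched only for summarization is not billed against the hosted-provider quota a second time. A minimal usage sketch under that reading; the tenant id, provider, and model names below are placeholders, not values taken from this patch:

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.models.entity.model_params import ModelKwargs

    # Primary agent model: deduct_quota defaults to True, so quota accounting
    # behaves exactly as it did before this patch.
    agent_model_instance = ModelFactory.get_text_generation_model(
        tenant_id="tenant-123",           # placeholder
        model_provider_name="openai",     # placeholder
        model_name="gpt-3.5-turbo",       # placeholder
    )

    # Secondary summary model: reuse the agent's provider and model and skip
    # the extra quota deduction, mirroring the orchestrator changes below.
    summary_model_instance = ModelFactory.get_text_generation_model(
        tenant_id="tenant-123",
        model_provider_name=agent_model_instance.model_provider.provider_name,
        model_name=agent_model_instance.name,
        model_kwargs=ModelKwargs(temperature=0, max_tokens=500),
        deduct_quota=False,
    )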
diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py
index f4eed96ff5..4f87e45e0d 100644
--- a/api/core/orchestrator_rule_parser.py
+++ b/api/core/orchestrator_rule_parser.py
@@ -17,12 +17,13 @@ from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.current_datetime_tool import DatetimeTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
 from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
-from libs import helper
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
 
@@ -82,15 +83,19 @@
             try:
                 summary_model_instance = ModelFactory.get_text_generation_model(
                     tenant_id=self.tenant_id,
+                    model_provider_name=agent_provider_name,
+                    model_name=agent_model_name,
                     model_kwargs=ModelKwargs(
                         temperature=0,
                         max_tokens=500
-                    )
+                    ),
+                    deduct_quota=False
                 )
             except ProviderTokenNotInitError as e:
                 summary_model_instance = None
 
             tools = self.to_tools(
+                agent_model_instance=agent_model_instance,
                 tool_configs=tool_configs,
                 conversation_message_task=conversation_message_task,
                 rest_tokens=rest_tokens,
@@ -140,11 +145,12 @@
 
         return None
 
-    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
                  rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
+        :param agent_model_instance:
         :param rest_tokens:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:
@@ -162,7 +168,7 @@
             if tool_type == "dataset":
                 tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
             elif tool_type == "web_reader":
-                tool = self.to_web_reader_tool()
+                tool = self.to_web_reader_tool(agent_model_instance)
             elif tool_type == "google_search":
                 tool = self.to_google_search_tool()
             elif tool_type == "wikipedia":
@@ -207,24 +213,28 @@
 
         return tool
 
-    def to_web_reader_tool(self) -> Optional[BaseTool]:
+    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
         """
         A tool for reading web pages
         :return:
         """
-        summary_model_instance = ModelFactory.get_text_generation_model(
-            tenant_id=self.tenant_id,
-            model_kwargs=ModelKwargs(
-                temperature=0,
-                max_tokens=500
+        try:
+            summary_model_instance = ModelFactory.get_text_generation_model(
+                tenant_id=self.tenant_id,
+                model_provider_name=agent_model_instance.model_provider.provider_name,
+                model_name=agent_model_instance.name,
+                model_kwargs=ModelKwargs(
+                    temperature=0,
+                    max_tokens=500
+                ),
+                deduct_quota=False
             )
-        )
-
-        summary_llm = summary_model_instance.client
+        except ProviderTokenNotInitError:
+            summary_model_instance = None
 
         tool = WebReaderTool(
-            llm=summary_llm,
+            llm=summary_model_instance.client if summary_model_instance else None,
             max_chunk_length=4000,
             continue_reading=True,
             callbacks=[DifyStdOutCallbackHandler()]
         )
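Note: read as a whole, the '+' side of the to_web_reader_tool hunk above assembles to roughly the following method (a sketch, not the verbatim file; it assumes the imports added at the top of this diff and the DifyStdOutCallbackHandler import already present in the module):

    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
        try:
            # Follow the agent's own provider/model instead of the tenant's
            # (possibly invalid) default model, and skip quota deduction.
            summary_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=agent_model_instance.model_provider.provider_name,
                model_name=agent_model_instance.name,
                model_kwargs=ModelKwargs(temperature=0, max_tokens=500),
                deduct_quota=False
            )
        except ProviderTokenNotInitError:
            # No usable credentials for the provider: degrade gracefully and
            # build the web reader without a summarization LLM rather than
            # failing the whole universal chat request.
            summary_model_instance = None

        tool = WebReaderTool(
            llm=summary_model_instance.client if summary_model_instance else None,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool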
@@ -252,11 +262,7 @@
         return tool
 
     def to_current_datetime_tool(self) -> Optional[BaseTool]:
-        tool = Tool(
-            name="current_datetime",
-            description="A tool when you want to get the current date, time, week, month or year, "
-                        "and the time zone is UTC. Result is \"