From f073dca22aac94cb48c509c32f999ca880338a94 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 10 Mar 2024 15:48:31 +0800 Subject: [PATCH] feat: optimize db connection when llm invoking (#2774) --- api/core/app_runner/assistant_app_runner.py | 4 ++++ api/core/app_runner/basic_app_runner.py | 2 ++ api/core/app_runner/generate_task_pipeline.py | 8 +++++++ api/core/application_manager.py | 6 ++--- api/core/features/assistant_base_runner.py | 22 +++++++++++++++++-- 5 files changed, 37 insertions(+), 5 deletions(-) diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index d9a3447bda..655a5a1c7c 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING + db.session.refresh(conversation) + db.session.refresh(message) + db.session.close() + # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: assistant_cot_runner = AssistantCotApplicationRunner( diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index 83f4f6929a..d3c91337c8 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner): model=app_orchestration_config.model_config.model ) + db.session.close() + invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_orchestration_config.model_config.parameters, diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 5fd635bc3b..1cc56483ad 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -89,6 +89,10 @@ class GenerateTaskPipeline: Process generate task pipeline. :return: """ + db.session.refresh(self._conversation) + db.session.refresh(self._message) + db.session.close() + if stream: return self._process_stream_response() else: @@ -303,6 +307,7 @@ class GenerateTaskPipeline: .first() ) db.session.refresh(agent_thought) + db.session.close() if agent_thought: response = { @@ -330,6 +335,8 @@ class GenerateTaskPipeline: .filter(MessageFile.id == event.message_file_id) .first() ) + db.session.close() + # get extension if '.' in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' @@ -413,6 +420,7 @@ class GenerateTaskPipeline: usage = llm_result.usage self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) self._message.message_tokens = usage.prompt_tokens diff --git a/api/core/application_manager.py b/api/core/application_manager.py index e073eac4b9..9aca61c7bb 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -201,7 +201,7 @@ class ApplicationManager: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, queue_manager: ApplicationQueueManager, @@ -233,8 +233,6 @@ class ApplicationManager: else: logger.exception(e) raise e - finally: - db.session.remove() def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ -> AppOrchestrationConfigEntity: @@ -651,6 +649,7 @@ class ApplicationManager: db.session.add(conversation) db.session.commit() + db.session.refresh(conversation) else: conversation = ( db.session.query(Conversation) @@ -689,6 +688,7 @@ class ApplicationManager: db.session.add(message) db.session.commit() + db.session.refresh(message) for file in application_generate_entity.files: message_file = MessageFile( diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 0ee6436d11..1d9541070f 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner): self.agent_thought_count = db.session.query(MessageAgentThought).filter( MessageAgentThought.message_id == self.message.id, ).count() + db.session.close() # check if model supports stream tool call llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) @@ -341,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner): created_by=self.user_id, ) db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + result.append(( message_file, message.save_as )) - - db.session.commit() + db.session.close() + return result def create_agent_thought(self, message_id: str, message: str, @@ -384,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner): db.session.add(thought) db.session.commit() + db.session.refresh(thought) + db.session.close() self.agent_thought_count += 1 @@ -401,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner): """ Save agent thought """ + agent_thought = db.session.query(MessageAgentThought).filter( + MessageAgentThought.id == agent_thought.id + ).first() + if thought is not None: agent_thought.thought = thought @@ -451,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner): agent_thought.tool_labels_str = json.dumps(labels) db.session.commit() + db.session.close() def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ @@ -523,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner): """ convert tool variables to db variables """ + db_variables = db.session.query(ToolConversationVariables).filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ).first() + db_variables.updated_at = datetime.utcnow() db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() + db.session.close() def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ @@ -581,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner): if message.answer: result.append(AssistantPromptMessage(content=message.answer)) + db.session.close() + return result \ No newline at end of file