diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 44db0d4a33..9866db12f6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -98,6 +98,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc ) self._stream_generate_routes = self._get_stream_generate_routes() + self._conversation_name_generate_thread = None def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ @@ -108,6 +109,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc db.session.refresh(self._user) db.session.close() + # start generate conversation name thread + self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation, + self._application_generate_entity.query + ) + generator = self._process_stream_response() if self._stream: return self._to_stream_response(generator) @@ -278,6 +285,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc else: continue + if self._conversation_name_generate_thread: + self._conversation_name_generate_thread.join() + def _save_message(self) -> None: """ Save message. diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 4fc9d6abaa..a7dbb4754c 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -97,6 +97,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) ) + self._conversation_name_generate_thread = None + def process(self) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -110,6 +112,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan db.session.refresh(self._message) db.session.close() + if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: + # start generate conversation name thread + self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation, + self._application_generate_entity.query + ) + generator = self._process_stream_response() if self._stream: return self._to_stream_response(generator) @@ -256,6 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan else: continue + if self._conversation_name_generate_thread: + self._conversation_name_generate_thread.join() + def _save_message(self) -> None: """ Save message. diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 16eb3d4fc2..2848455278 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -1,5 +1,8 @@ +from threading import Thread from typing import Optional, Union +from flask import Flask, current_app + from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -19,9 +22,10 @@ from core.app.entities.task_entities import ( MessageReplaceStreamResponse, MessageStreamResponse, ) +from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db -from models.model import MessageAnnotation, MessageFile +from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -34,6 +38,59 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + """ + Generate conversation name. + :param conversation: conversation + :param query: query + :return: thread + """ + is_first_message = self._application_generate_entity.conversation_id is None + extras = self._application_generate_entity.extras + auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + + if auto_generate_conversation_name and is_first_message: + # start generate thread + thread = Thread(target=self._generate_conversation_name_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'conversation_id': conversation.id, + 'query': query + }) + + thread.start() + + return thread + + return None + + def _generate_conversation_name_worker(self, + flask_app: Flask, + conversation_id: str, + query: str): + with flask_app.app_context(): + # get conversation and message + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id) + .first() + ) + + if conversation.mode != AppMode.COMPLETION.value: + app_model = conversation.app + if not app_model: + return + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) + conversation.name = name + except: + pass + + db.session.merge(conversation) + db.session.commit() + db.session.close() + def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: """ Handle annotation reply. diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index e0f3b84990..9a7c0deb20 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -5,7 +5,6 @@ from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle from .deduct_quota_when_messaeg_created import handle from .delete_installed_app_when_app_deleted import handle -from .generate_conversation_name_when_first_message_created import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_provider_last_used_at_when_messaeg_created import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py deleted file mode 100644 index 31535bf4ef..0000000000 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ /dev/null @@ -1,32 +0,0 @@ -from core.llm_generator.llm_generator import LLMGenerator -from events.message_event import message_was_created -from extensions.ext_database import db -from models.model import AppMode - - -@message_was_created.connect -def handle(sender, **kwargs): - message = sender - conversation = kwargs.get('conversation') - is_first_message = kwargs.get('is_first_message') - extras = kwargs.get('extras', {}) - - auto_generate_conversation_name = True - if extras: - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) - - if auto_generate_conversation_name and is_first_message: - if conversation.mode != AppMode.COMPLETION.value: - app_model = conversation.app - if not app_model: - return - - # generate conversation name - try: - name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) - conversation.name = name - except: - pass - - db.session.merge(conversation) - db.session.commit()