diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index efa6f55956..4fe22821f3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from factories import file_factory from models.account import Account from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow +from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -105,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation = None conversation_id = args.get("conversation_id") if conversation_id: - conversation = self._get_conversation_by_user( + conversation = ConversationService.get_conversation( app_model=app_model, conversation_id=conversation_id, user=user ) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index e47428b557..23abe41080 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -24,6 +24,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -98,9 +99,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) - + conversation_id = args.get("conversation_id") + if conversation_id: + conversation = ConversationService.get_conversation( + app_model=app_model, conversation_id=conversation_id, user=user + ) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 213fb79b7d..5fc9ed55af 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -24,6 +24,7 @@ from extensions.ext_database import db from factories import file_factory from models.account import Account from models.model import App, EndUser +from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -91,9 +92,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) - + conversation_id = args.get("conversation_id") + if conversation_id: + conversation = ConversationService.get_conversation( + app_model=app_model, conversation_id=conversation_id, user=user + ) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 909f76cc9a..efaa7b6756 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -4,8 +4,6 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Optional, Union, cast -from sqlalchemy import and_ - from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError @@ -30,7 +28,7 @@ from models import Account from models.enums import CreatedByRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError -from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError +from services.errors.conversation import ConversationNotExistsError logger = logging.getLogger(__name__) @@ -81,31 +79,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") raise e - def _get_conversation_by_user( - self, app_model: App, conversation_id: str, user: Union[Account, EndUser] - ) -> Conversation: - conversation_filter = [ - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.status == "normal", - Conversation.is_deleted.is_(False), - ] - - if isinstance(user, Account): - conversation_filter.append(Conversation.from_account_id == user.id) - else: - conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) - - conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() - - if not conversation: - raise ConversationNotExistsError() - - if conversation.status != "normal": - raise ConversationCompletedError() - - return conversation - def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: app_model_config = ( diff --git a/api/services/message_service.py b/api/services/message_service.py index 480d038623..aefab1556c 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -15,7 +15,6 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.account import Account from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from services.conversation_service import ConversationService -from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, LastMessageNotExistsError, @@ -210,12 +209,6 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - if not conversation: - raise ConversationNotExistsError() - - if conversation.status != "normal": - raise ConversationCompletedError() - model_manager = ModelManager() if app_model.mode == AppMode.ADVANCED_CHAT.value: