diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 0df26637a3..81477533e7 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -2,8 +2,6 @@ import decimal import json from typing import Optional, Union -from gunicorn.config import User - from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.callback_handler.entity.llm_message import LLMMessage @@ -269,7 +267,7 @@ class ConversationMessageTask: class PubHandler: - def __init__(self, user: Union[Account | User], task_id: str, + def __init__(self, user: Union[Account | EndUser], task_id: str, message: Message, conversation: Conversation, chain_pub: bool = False, agent_thought_pub: bool = False): self._channel = PubHandler.generate_channel_name(user, task_id) @@ -282,12 +280,12 @@ class PubHandler: self._agent_thought_pub = agent_thought_pub @classmethod - def generate_channel_name(cls, user: Union[Account | User], task_id: str): + def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str): user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id return "generate_result:{}-{}".format(user_str, task_id) @classmethod - def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str): + def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id return "generate_result_stopped:{}-{}".format(user_str, task_id) @@ -366,7 +364,7 @@ class PubHandler: redis_client.publish(self._channel, json.dumps(content)) @classmethod - def pub_error(cls, user: Union[Account | User], task_id: str, e): + def pub_error(cls, user: Union[Account | EndUser], task_id: str, e): content = { 'error': type(e).__name__, 'description': e.description if getattr(e, 'description', None) is not None else str(e) @@ -379,7 +377,7 @@ class PubHandler: return redis_client.get(self._stopped_cache_key) is not None @classmethod - def stop(cls, user: Union[Account | User], task_id: str): + def stop(cls, user: Union[Account | EndUser], task_id: str): stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) redis_client.setex(stopped_cache_key, 600, 1)