diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 1073a0f2e4..691d178ba2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping from threading import Thread from typing import Any, Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _conversation_name_generate_thread: Optional[Thread] = None @@ -96,32 +97,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream: bool, dialogue_count: int, ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - :param dialogue_count: dialogue count - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) - self._workflow = workflow - self._conversation = conversation - self._message = message self._workflow_system_variables = { SystemVariableKey.QUERY: message.query, SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: self._user_id, SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, @@ -139,13 +143,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc Process generate task pipeline. :return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -171,12 +171,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return ChatbotAppBlockingResponse( task_id=stream_response.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -194,9 +194,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ for stream_response in generator: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -214,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -274,26 +274,33 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + message.workflow_run_id = workflow_run.id + session.commit() - self._refetch_message() - self._message.workflow_run_id = workflow_run.id - - db.session.commit() - db.session.refresh(self._message) - db.session.close() - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, @@ -304,28 +311,28 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_run=workflow_run, event=event ) - response = self._workflow_node_retry_to_stream_response( + node_retry_resp = self._workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if node_retry_resp: + yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response_start = self._workflow_node_start_to_stream_response( + node_start_resp = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response_start: - yield response_start + if node_start_resp: + yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -333,25 +340,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response_finish = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response_finish: - yield response_finish + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response_finish = self._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - - if response: - yield response + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: @@ -395,20 +401,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if not graph_runtime_state: raise ValueError("workflow run not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: @@ -417,21 +427,25 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: @@ -440,71 +454,73 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: + with Session(db.engine) as session: workflow_run = self._handle_workflow_run_failed( + session=session, workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), - conversation_id=self._conversation.id, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation_id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err = self._handle_error(event=err_event, session=session, message_id=self._message_id) + session.commit() + yield workflow_finish_resp + yield self._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) - # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() + yield workflow_finish_resp yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -521,7 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._task_state.answer += delta_text yield self._message_to_stream_response( - answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation @@ -536,7 +552,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + with Session(db.engine) as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() yield self._message_end_to_stream_response() else: @@ -549,54 +567,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - self._refetch_message() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = ( + def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + message = self._get_message(session=session) + message.answer = self._task_state.answer + message.provider_response_latency = time.perf_counter() - self._start_at + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) message_files = [ MessageFile( - message_id=self._message.id, + message_id=message.id, type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], created_by_role=CreatedByRole.ACCOUNT - if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatedByRole.END_USER, - created_by=self._message.from_account_id or self._message.from_end_user_id or "", + created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files ] - db.session.add_all(message_files) + session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency self._task_state.metadata["usage"] = jsonable_encoder(usage) else: self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) - - db.session.commit() - message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -613,7 +623,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, files=self._recorded_files, metadata=extras.get("metadata", {}), ) @@ -641,11 +651,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return False - def _refetch_message(self) -> None: - """ - Refetch message. - :return: - """ - message = db.session.query(Message).filter(Message.id == self._message.id).first() - if message: - self._message = message + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index c2e35faf89..dcd9463b8a 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -70,7 +70,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): queue_manager=queue_manager, conversation=conversation, message=message, - user=user, stream=stream, ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c47b38f560..574596d4f5 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,6 +3,8 @@ import time from collections.abc import Generator from typing import Any, Optional, Union +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager @@ -50,6 +52,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import EndUser from models.workflow import ( Workflow, @@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariableKey, Any] @@ -83,25 +84,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param user: user - :param stream: is streamed - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict - self._workflow = workflow self._workflow_system_variables = { SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id, + SystemVariableKey.USER_ID: self._user_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, @@ -115,10 +118,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa Process generate task pipeline. :return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) @@ -185,7 +184,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -242,18 +241,26 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(event, QueuePingEvent): yield self._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event) + err = self._handle_error(event=event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine) as session: + # init workflow run + workflow_run = self._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + start_resp = self._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield start_resp elif isinstance( event, QueueNodeRetryEvent, @@ -350,22 +357,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + session.commit() + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: raise ValueError("workflow run not initialized.") @@ -373,49 +386,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_partial_success( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: raise ValueError("workflow run not initialized.") if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) + with Session(db.engine) as session: + workflow_run = self._handle_workflow_run_failed( + session=session, + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) - # save workflow app log - self._save_workflow_app_log(workflow_run) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -435,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: """ Save workflow app log. :return: @@ -457,12 +479,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" - workflow_app_log.created_by = self._user.id + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id - db.session.add(workflow_app_log) - db.session.commit() - db.session.close() + session.add(workflow_app_log) def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 03a81353d0..e363a7f642 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,9 @@ import logging import time -from typing import Optional, Union +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( @@ -17,9 +20,7 @@ from core.app.entities.task_entities import ( from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, Message +from models.model import Message logger = logging.getLogger(__name__) @@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline: self, application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, - user: Union[Account, EndUser], stream: bool, ) -> None: """ @@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline: """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager - self._user = user self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): - """ - Handle error event. - :param event: event - :param message: message - :return: - """ + def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline: else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) - if message: - refetch_message = db.session.query(Message).filter(Message.id == message.id).first() + if not message_id or not session: + return err - if refetch_message: - err_desc = self._error_to_desc(err) - refetch_message.status = "error" - refetch_message.error = err_desc - - db.session.commit() + stmt = select(Message).where(Message.id == message_id) + message = session.scalar(stmt) + if not message: + return err + err_desc = self._error_to_desc(err) + message.status = "error" + message.error = err_desc return err def _error_to_desc(self, e: Exception) -> str: diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b9f8e7ca56..c84f8ba3e4 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -5,6 +5,9 @@ from collections.abc import Generator from threading import Thread from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.account import Account -from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan queue_manager: AppQueueManager, conversation: Conversation, message: Message, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) self._model_config = application_generate_entity.model_conf self._app_config = application_generate_entity.app_config - self._conversation = conversation - self._message = message + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) self._task_state = EasyUITaskState( llm_result=LLMResult( @@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan CompletionAppBlockingResponse, Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query or "" + conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._task_state.metadata: extras["metadata"] = self._task_state.metadata response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation.mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan response = ChatbotAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=cast(str, self._task_state.llm_result.message.content), - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan for stream_response in generator: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): yield CompletionAppStreamResponse( - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = output_moderation_answer yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message(trace_manager) - - yield self._message_end_to_stream_response() + with Session(db.engine) as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + session.commit() + message_end_resp = self._message_end_to_stream_response() + yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): @@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) else: - yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan llm_result = self._task_state.llm_result usage = llm_result.usage - message = db.session.query(Message).filter(Message.id == self._message.id).first() + message_stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(message_stmt) if not message: - raise Exception(f"Message {self._message.id} not found") - self._message = message - conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + raise ValueError(f"message {self._message_id} not found") + conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) + conversation = session.scalar(conversation_stmt) if not conversation: - raise Exception(f"Conversation {self._conversation.id} not found") - self._conversation = conversation + raise ValueError(f"Conversation {self._conversation_id} not found") - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = ( + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer = ( PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - self._message.message_metadata = ( + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.provider_response_latency = time.perf_counter() - self._start_at + message.total_price = usage.total_price + message.currency = usage.currency + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) - db.session.commit() - if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id ) ) message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} - and hasattr(self._application_generate_entity, "conversation_id") - and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, - id=self._message.id, + id=self._message_id, metadata=extras.get("metadata", {}), ) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 007543f6d0..15f2c25c66 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -36,7 +36,7 @@ class MessageCycleManage: ] _task_state: Union[EasyUITaskState, WorkflowTaskState] - def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation: conversation @@ -56,7 +56,7 @@ class MessageCycleManage: target=self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "conversation_id": conversation.id, + "conversation_id": conversation_id, "query": query, }, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index f581e564f2..2692008c66 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime from typing import Any, Optional, Union, cast from uuid import uuid4 +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -63,27 +64,34 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - def _handle_workflow_run_start(self) -> WorkflowRun: - max_sequence = ( - db.session.query(db.func.max(WorkflowRun.sequence_number)) - .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) - .filter(WorkflowRun.app_id == self._workflow.app_id) - .scalar() - or 0 + def _handle_workflow_run_start( + self, + *, + session: Session, + workflow_id: str, + user_id: str, + created_by_role: CreatedByRole, + ) -> WorkflowRun: + workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.scalar(workflow_stmt) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + + max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( + WorkflowRun.tenant_id == workflow.tenant_id, + WorkflowRun.app_id == workflow.app_id, ) + max_sequence = session.scalar(max_sequence_stmt) or 0 new_sequence_number = max_sequence + 1 inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue - inputs[f"sys.{key.value}"] = value triggered_from = ( @@ -96,33 +104,32 @@ class WorkflowCycleManage: inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = WorkflowRun() - system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - workflow_run.id = system_id or str(uuid4()) - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER - ) - workflow_run.created_by = self._user.id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) - session.add(workflow_run) - session.commit() + workflow_run = WorkflowRun() + workflow_run.id = workflow_run_id + workflow_run.tenant_id = workflow.tenant_id + workflow_run.app_id = workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = workflow.id + workflow_run.type = workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = workflow.version + workflow_run.graph = workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = created_by_role + workflow_run.created_by = user_id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_run) return workflow_run def _handle_workflow_run_success( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -141,7 +148,7 @@ class WorkflowCycleManage: :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -152,9 +159,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - db.session.refresh(workflow_run) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -165,12 +169,12 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_partial_success( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -190,7 +194,7 @@ class WorkflowCycleManage: :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) @@ -201,8 +205,6 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - db.session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -214,12 +216,12 @@ class WorkflowCycleManage: ) ) - db.session.close() - return workflow_run def _handle_workflow_run_failed( self, + *, + session: Session, workflow_run: WorkflowRun, start_at: float, total_tokens: int, @@ -240,7 +242,7 @@ class WorkflowCycleManage: :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) workflow_run.status = status.value workflow_run.error = error @@ -249,21 +251,18 @@ class WorkflowCycleManage: workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - running_workflow_node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, - ) - .all() + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) + running_workflow_node_executions = session.scalars(stmt).all() + for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error @@ -271,13 +270,6 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = ( workflow_node_execution.finished_at - workflow_node_execution.created_at ).total_seconds() - db.session.commit() - - db.session.close() - - # with Session(db.engine, expire_on_commit=False) as session: - # session.add(workflow_run) - # session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -485,14 +477,14 @@ class WorkflowCycleManage: ################################################# def _workflow_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - """ - Workflow start to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return WorkflowStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -506,36 +498,32 @@ class WorkflowCycleManage: ) def _workflow_finish_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: - """ - Workflow finish to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ - # Attach WorkflowRun to an active session so "created_by_role" can be accessed. - workflow_run = db.session.merge(workflow_run) - - # Refresh to ensure any expired attributes are fully loaded - db.session.refresh(workflow_run) - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: - created_by_account = workflow_run.created_by_account - if created_by_account: + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + "id": account.id, + "name": account.name, + "email": account.email, + } + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: + created_by = { + "id": end_user.id, + "user": end_user.session_id, } else: - created_by_end_user = workflow_run.created_by_end_user - if created_by_end_user: - created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, - } + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") return WorkflowFinishStreamResponse( task_id=task_id, @@ -895,14 +883,14 @@ class WorkflowCycleManage: return None - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run :param workflow_run_id: workflow run id :return: """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() - + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) if not workflow_run: raise WorkflowRunNotFoundError(workflow_run_id) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f538eaef5b..691cb8d400 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -9,6 +9,8 @@ from typing import Any, Optional, Union from uuid import UUID, uuid4 from flask import current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -329,15 +331,15 @@ class TraceTask: ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run = workflow_run + self.workflow_run_id = workflow_run.id if workflow_run else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer - self.kwargs = kwargs self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.app_id = None + self.kwargs = kwargs + def execute(self): return self.preprocess() @@ -345,19 +347,23 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - self.workflow_run, self.conversation_id, self.user_id + workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + ), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs ), TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( - self.conversation_id, self.timer, **self.kwargs + conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), } @@ -367,86 +373,100 @@ class TraceTask: def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): - if not workflow_run: - raise ValueError("Workflow run not found") + def workflow_trace( + self, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + ): + if not workflow_run_id: + return {} - db.session.merge(workflow_run) - db.session.refresh(workflow_run) + with Session(db.engine) as session: + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalars(workflow_run_stmt).first() + if not workflow_run: + raise ValueError("Workflow run not found") - workflow_id = workflow_run.workflow_id - tenant_id = workflow_run.tenant_id - workflow_run_id = workflow_run.id - workflow_run_elapsed_time = workflow_run.elapsed_time - workflow_run_status = workflow_run.status - workflow_run_inputs = workflow_run.inputs_dict - workflow_run_outputs = workflow_run.outputs_dict - workflow_run_version = workflow_run.version - error = workflow_run.error or "" + workflow_id = workflow_run.workflow_id + tenant_id = workflow_run.tenant_id + workflow_run_id = workflow_run.id + workflow_run_elapsed_time = workflow_run.elapsed_time + workflow_run_status = workflow_run.status + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict + workflow_run_version = workflow_run.version + error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = workflow_run.total_tokens - file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + file_list = workflow_run_inputs.get("sys.file") or [] + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" - # get workflow_app_log_id - workflow_app_log_data = ( - db.session.query(WorkflowAppLog) - .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) - .first() - ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None - # get message_id - message_data = ( - db.session.query(Message.id) - .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) - .first() - ) - message_id = str(message_data.id) if message_data else None + # get workflow_app_log_id + workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.app_id == workflow_run.app_id, + WorkflowAppLog.workflow_run_id == workflow_run.id, + ) + workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) + # get message_id + message_id = None + if conversation_id: + message_data_stmt = select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_run_id, + ) + message_id = session.scalar(message_data_stmt) - metadata = { - "workflow_id": workflow_id, - "conversation_id": conversation_id, - "workflow_run_id": workflow_run_id, - "tenant_id": tenant_id, - "elapsed_time": workflow_run_elapsed_time, - "status": workflow_run_status, - "version": workflow_run_version, - "total_tokens": total_tokens, - "file_list": file_list, - "triggered_form": workflow_run.triggered_from, - "user_id": user_id, - } - - workflow_trace_info = WorkflowTraceInfo( - workflow_data=workflow_run.to_dict(), - conversation_id=conversation_id, - workflow_id=workflow_id, - tenant_id=tenant_id, - workflow_run_id=workflow_run_id, - workflow_run_elapsed_time=workflow_run_elapsed_time, - workflow_run_status=workflow_run_status, - workflow_run_inputs=workflow_run_inputs, - workflow_run_outputs=workflow_run_outputs, - workflow_run_version=workflow_run_version, - error=error, - total_tokens=total_tokens, - file_list=file_list, - query=query, - metadata=metadata, - workflow_app_log_id=workflow_app_log_id, - message_id=message_id, - start_time=workflow_run.created_at, - end_time=workflow_run.finished_at, - ) + metadata = { + "workflow_id": workflow_id, + "conversation_id": conversation_id, + "workflow_run_id": workflow_run_id, + "tenant_id": tenant_id, + "elapsed_time": workflow_run_elapsed_time, + "status": workflow_run_status, + "version": workflow_run_version, + "total_tokens": total_tokens, + "file_list": file_list, + "triggered_form": workflow_run.triggered_from, + "user_id": user_id, + } + workflow_trace_info = WorkflowTraceInfo( + workflow_data=workflow_run.to_dict(), + conversation_id=conversation_id, + workflow_id=workflow_id, + tenant_id=tenant_id, + workflow_run_id=workflow_run_id, + workflow_run_elapsed_time=workflow_run_elapsed_time, + workflow_run_status=workflow_run_status, + workflow_run_inputs=workflow_run_inputs, + workflow_run_outputs=workflow_run_outputs, + workflow_run_version=workflow_run_version, + error=error, + total_tokens=total_tokens, + file_list=file_list, + query=query, + metadata=metadata, + workflow_app_log_id=workflow_app_log_id, + message_id=message_id, + start_time=workflow_run.created_at, + end_time=workflow_run.finished_at, + ) return workflow_trace_info - def message_trace(self, message_id): + def message_trace(self, message_id: str | None): + if not message_id: + return {} message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode = db.session.scalars(conversation_mode_stmt).all() + if not conversation_mode or len(conversation_mode) == 0: + return {} conversation_mode = conversation_mode[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 998eba9ea9..8b06df1930 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -18,7 +18,7 @@ def filter_none_values(data: dict): return new_data -def get_message_data(message_id): +def get_message_data(message_id: str): return db.session.query(Message).filter(Message.id == message_id).first() diff --git a/api/models/account.py b/api/models/account.py index 88c96da1a1..35a28df750 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,6 +3,7 @@ import json from flask_login import UserMixin # type: ignore from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column from .engine import db from .types import StringUUID @@ -20,7 +21,7 @@ class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) diff --git a/api/models/model.py b/api/models/model.py index 2a593f0829..d2d4d5853f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -530,13 +530,13 @@ class Conversation(db.Model): # type: ignore[name-defined] db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) - mode = db.Column(db.String(255), nullable=False) + mode: Mapped[str] = mapped_column(db.String(255)) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -770,7 +770,7 @@ class Message(db.Model): # type: ignore[name-defined] db.Index("message_created_at_idx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -797,7 +797,7 @@ class Message(db.Model): # type: ignore[name-defined] from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined] external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id = db.Column(db.String(255), nullable=False) + session_id: Mapped[str] = mapped_column() created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 880e044d07..78a7f8169f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -392,40 +392,28 @@ class WorkflowRun(db.Model): # type: ignore[name-defined] db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[str] = mapped_column(db.Text) + inputs: Mapped[str] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error = db.Column(db.Text) + error: Mapped[str] = mapped_column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) # account, end_user + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) - @property - def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None - - @property - def created_by_end_user(self): - from models.model import EndUser - - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None - @property def graph_dict(self): return json.loads(self.graph) if self.graph else {} @@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model): # type: ignore[name-defined] db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False)