refactor: optimize database usage (#12071)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-25 16:24:52 +08:00 committed by GitHub
parent b281a80150
commit 83ea931e3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 574 additions and 561 deletions

View File

@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping
from threading import Thread from threading import Thread
from typing import Any, Optional, Union from typing import Any, Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity _application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None _conversation_name_generate_thread: Optional[Thread] = None
@ -96,32 +97,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool, stream: bool,
dialogue_count: int, dialogue_count: int,
) -> None: ) -> None:
""" super().__init__(
Initialize AdvancedChatAppGenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param workflow: workflow stream=stream,
:param queue_manager: queue manager )
:param conversation: conversation
:param message: message
:param user: user
:param stream: stream
:param dialogue_count: dialogue count
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser): if isinstance(user, EndUser):
user_id = self._user.session_id self._user_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else: else:
user_id = self._user.id raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query, SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id, SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id, SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
@ -139,13 +143,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
Process generate task pipeline. Process generate task pipeline.
:return: :return:
""" """
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query conversation_id=self._conversation_id, query=self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@ -171,12 +171,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return ChatbotAppBlockingResponse( return ChatbotAppBlockingResponse(
task_id=stream_response.task_id, task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data( data=ChatbotAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
answer=self._task_state.answer, answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -194,9 +194,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
for stream_response in generator: for stream_response in generator:
yield ChatbotAppStreamResponse( yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
@ -214,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow_features_dict
if ( if (
features_dict.get("text_to_speech") features_dict.get("text_to_speech")
@ -274,26 +274,33 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message) with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
# init workflow run with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_start() # init workflow run
workflow_run = self._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
session.commit()
self._refetch_message() workflow_start_resp = self._workflow_start_to_stream_response(
self._message.workflow_run_id = workflow_run.id session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
db.session.commit() yield workflow_start_resp
db.session.refresh(self._message)
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance( elif isinstance(
event, event,
QueueNodeRetryEvent, QueueNodeRetryEvent,
@ -304,28 +311,28 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
response = self._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_node_retry_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
if response: if node_retry_resp:
yield response yield node_retry_resp
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not workflow_run:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response_start = self._workflow_node_start_to_stream_response( node_start_resp = self._workflow_node_start_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
if response_start: if node_start_resp:
yield response_start yield node_start_resp
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event) workflow_node_execution = self._handle_workflow_node_execution_success(event)
@ -333,25 +340,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if event.node_type in [NodeType.ANSWER, NodeType.END]: if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response_finish = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
if response_finish: if node_finish_resp:
yield response_finish yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event) workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response_finish = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_node_finish_to_stream_response(
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
if node_finish_resp:
if response: yield node_finish_resp
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not workflow_run:
@ -395,20 +401,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_run = self._handle_workflow_run_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run=workflow_run,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
conversation_id=self._conversation.id, total_steps=graph_runtime_state.node_run_steps,
trace_manager=trace_manager, outputs=event.outputs,
) conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run: if not workflow_run:
@ -417,21 +427,25 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_partial_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run=workflow_run,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
exceptions_count=event.exceptions_count, total_steps=graph_runtime_state.node_run_steps,
conversation_id=None, outputs=event.outputs,
trace_manager=trace_manager, exceptions_count=event.exceptions_count,
) conversation_id=None,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run: if not workflow_run:
@ -440,71 +454,73 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( with Session(db.engine) as session:
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run=workflow_run, workflow_run=workflow_run,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED, status=WorkflowRunStatus.FAILED,
error=event.get_stop_reason(), error=event.error,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
trace_manager=trace_manager, trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response(
yield self._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
session.commit()
yield workflow_finish_resp
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
# Save message workflow_finish_resp = self._workflow_finish_to_stream_response(
self._save_message(graph_runtime_state=graph_runtime_state) session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
yield workflow_finish_resp
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
break break
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._handle_retriever_resources(event)
self._refetch_message() with Session(db.engine) as session:
message = self._get_message(session=session)
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit()
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event) self._handle_annotation_reply(event)
self._refetch_message() with Session(db.engine) as session:
message = self._get_message(session=session)
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit()
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
if delta_text is None: if delta_text is None:
@ -521,7 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response( yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
@ -536,7 +552,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message # Save message
self._save_message(graph_runtime_state=graph_runtime_state) with Session(db.engine) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
else: else:
@ -549,54 +567,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
self._refetch_message() message = self._get_message(session=session)
message.answer = self._task_state.answer
self._message.answer = self._task_state.answer message.provider_response_latency = time.perf_counter() - self._start_at
self._message.provider_response_latency = time.perf_counter() - self._start_at message.message_metadata = (
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=self._message.id, message_id=message.id,
type=file["type"], type=file["type"],
transfer_method=file["transfer_method"], transfer_method=file["transfer_method"],
url=file["remote_url"], url=file["remote_url"],
belongs_to="assistant", belongs_to="assistant",
upload_file_id=file["related_id"], upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT created_by_role=CreatedByRole.ACCOUNT
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER, else CreatedByRole.END_USER,
created_by=self._message.from_account_id or self._message.from_end_user_id or "", created_by=message.from_account_id or message.from_end_user_id or "",
) )
for file in self._recorded_files for file in self._recorded_files
] ]
db.session.add_all(message_files) session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage: if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit message.message_price_unit = usage.prompt_price_unit
self._message.answer_tokens = usage.completion_tokens message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
self._message.total_price = usage.total_price message.total_price = usage.total_price
self._message.currency = usage.currency message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage) self._task_state.metadata["usage"] = jsonable_encoder(usage)
else: else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
db.session.commit()
message_was_created.send( message_was_created.send(
self._message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
) )
def _message_end_to_stream_response(self) -> MessageEndStreamResponse: def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -613,7 +623,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message.id, id=self._message_id,
files=self._recorded_files, files=self._recorded_files,
metadata=extras.get("metadata", {}), metadata=extras.get("metadata", {}),
) )
@ -641,11 +651,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return False return False
def _refetch_message(self) -> None: def _get_message(self, *, session: Session):
""" stmt = select(Message).where(Message.id == self._message_id)
Refetch message. message = session.scalar(stmt)
:return: if not message:
""" raise ValueError(f"Message not found: {self._message_id}")
message = db.session.query(Message).filter(Message.id == self._message.id).first() return message
if message:
self._message = message

View File

@ -70,7 +70,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
queue_manager=queue_manager, queue_manager=queue_manager,
conversation=conversation, conversation=conversation,
message=message, message=message,
user=user,
stream=stream, stream=stream,
) )

View File

@ -3,6 +3,8 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Any, Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -50,6 +52,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser from models.model import EndUser
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
@ -83,25 +84,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" super().__init__(
Initialize GenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param workflow: workflow stream=stream,
:param queue_manager: queue manager )
:param user: user
:param stream: is streamed
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser): if isinstance(user, EndUser):
user_id = self._user.session_id self._user_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else: else:
user_id = self._user.id raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id, SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
@ -115,10 +118,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
Process generate task pipeline. Process generate task pipeline.
:return: :return:
""" """
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
@ -185,7 +184,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow_features_dict
if ( if (
features_dict.get("text_to_speech") features_dict.get("text_to_speech")
@ -242,18 +241,26 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event) err = self._handle_error(event=event)
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
# init workflow run with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_start() # init workflow run
yield self._workflow_start_to_stream_response( workflow_run = self._handle_workflow_run_start(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session,
) workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
start_resp = self._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield start_resp
elif isinstance( elif isinstance(
event, event,
QueueNodeRetryEvent, QueueNodeRetryEvent,
@ -350,22 +357,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run=workflow_run,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
conversation_id=None, total_steps=graph_runtime_state.node_run_steps,
trace_manager=trace_manager, outputs=event.outputs,
) conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log # save workflow app log
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session,
) task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run: if not workflow_run:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
@ -373,49 +386,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_partial_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run=workflow_run,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
exceptions_count=event.exceptions_count, total_steps=graph_runtime_state.node_run_steps,
conversation_id=None, outputs=event.outputs,
trace_manager=trace_manager, exceptions_count=event.exceptions_count,
) conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log # save workflow app log
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run: if not workflow_run:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_failed(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run=workflow_run,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
status=WorkflowRunStatus.FAILED total_tokens=graph_runtime_state.total_tokens,
if isinstance(event, QueueWorkflowFailedEvent) total_steps=graph_runtime_state.node_run_steps,
else WorkflowRunStatus.STOPPED, status=WorkflowRunStatus.FAILED
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), if isinstance(event, QueueWorkflowFailedEvent)
conversation_id=None, else WorkflowRunStatus.STOPPED,
trace_manager=trace_manager, error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, conversation_id=None,
) trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log # save workflow app log
self._save_workflow_app_log(workflow_run) self._save_workflow_app_log(session=session, workflow_run=workflow_run)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
if delta_text is None: if delta_text is None:
@ -435,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
""" """
Save workflow app log. Save workflow app log.
:return: :return:
@ -457,12 +479,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user.id workflow_app_log.created_by = self._user_id
db.session.add(workflow_app_log) session.add(workflow_app_log)
db.session.commit()
db.session.close()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

View File

@ -1,6 +1,9 @@
import logging import logging
import time import time
from typing import Optional, Union from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
@ -17,9 +20,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration from core.moderation.output_moderation import ModerationRule, OutputModeration
from extensions.ext_database import db from models.model import Message
from models.account import Account
from models.model import EndUser, Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline:
self, self,
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" """
@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline:
""" """
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._user = user
self._start_at = time.perf_counter() self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation() self._output_moderation_handler = self._init_output_moderation()
self._stream = stream self._stream = stream
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
"""
Handle error event.
:param event: event
:param message: message
:return:
"""
logger.debug("error: %s", event.error) logger.debug("error: %s", event.error)
e = event.error e = event.error
err: Exception err: Exception
@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline:
else: else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message: if not message_id or not session:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first() return err
if refetch_message: stmt = select(Message).where(Message.id == message_id)
err_desc = self._error_to_desc(err) message = session.scalar(stmt)
refetch_message.status = "error" if not message:
refetch_message.error = err_desc return err
db.session.commit()
err_desc = self._error_to_desc(err)
message.status = "error"
message.error = err_desc
return err return err
def _error_to_desc(self, e: Exception) -> str: def _error_to_desc(self, e: Exception) -> str:

View File

@ -5,6 +5,9 @@ from collections.abc import Generator
from threading import Thread from threading import Thread
from typing import Optional, Union, cast from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" super().__init__(
Initialize GenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param queue_manager: queue manager stream=stream,
:param conversation: conversation )
:param message: message
:param user: user
:param stream: stream
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
self._model_config = application_generate_entity.model_conf self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config self._app_config = application_generate_entity.app_config
self._conversation = conversation
self._message = message self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._task_state = EasyUITaskState( self._task_state = EasyUITaskState(
llm_result=LLMResult( llm_result=LLMResult(
@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]: ]:
"""
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query or "" conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._task_state.metadata: if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value: if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data( data=CompletionAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
message_id=self._message.id, message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content), answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
response = ChatbotAppBlockingResponse( response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data( data=ChatbotAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content), answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
for stream_response in generator: for stream_response in generator:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
yield CompletionAppStreamResponse( yield CompletionAppStreamResponse(
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
else: else:
yield ChatbotAppStreamResponse( yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message) with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = output_moderation_answer self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message with Session(db.engine) as session:
self._save_message(trace_manager) # Save message
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response() session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent): if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(cast(str, delta_text), self._message.id) yield self._message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
else: else:
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent): elif isinstance(event, QueuePingEvent):
@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
""" """
Save message. Save message.
:return: :return:
@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result llm_result = self._task_state.llm_result
usage = llm_result.usage usage = llm_result.usage
message = db.session.query(Message).filter(Message.id == self._message.id).first() message_stmt = select(Message).where(Message.id == self._message_id)
message = session.scalar(message_stmt)
if not message: if not message:
raise Exception(f"Message {self._message.id} not found") raise ValueError(f"message {self._message_id} not found")
self._message = message conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() conversation = session.scalar(conversation_stmt)
if not conversation: if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found") raise ValueError(f"Conversation {self._conversation_id} not found")
self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages self._model_config.mode, self._task_state.llm_result.prompt_messages
) )
self._message.message_tokens = usage.prompt_tokens message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit message.message_price_unit = usage.prompt_price_unit
self._message.answer = ( message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content if llm_result.message.content
else "" else ""
) )
self._message.answer_tokens = usage.completion_tokens message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price message.total_price = usage.total_price
self._message.currency = usage.currency message.currency = usage.currency
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
db.session.commit()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
) )
) )
message_was_created.send( message_was_created.send(
self._message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
) )
def _handle_stop(self, event: QueueStopEvent) -> None: def _handle_stop(self, event: QueueStopEvent) -> None:
@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message.id, id=self._message_id,
metadata=extras.get("metadata", {}), metadata=extras.get("metadata", {}),
) )

View File

@ -36,7 +36,7 @@ class MessageCycleManage:
] ]
_task_state: Union[EasyUITaskState, WorkflowTaskState] _task_state: Union[EasyUITaskState, WorkflowTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """
Generate conversation name. Generate conversation name.
:param conversation: conversation :param conversation: conversation
@ -56,7 +56,7 @@ class MessageCycleManage:
target=self._generate_conversation_name_worker, target=self._generate_conversation_name_worker,
kwargs={ kwargs={
"flask_app": current_app._get_current_object(), # type: ignore "flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation.id, "conversation_id": conversation_id,
"query": query, "query": query,
}, },
) )

View File

@ -5,6 +5,7 @@ from datetime import UTC, datetime
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
@ -63,27 +64,34 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
class WorkflowCycleManage: class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
def _handle_workflow_run_start(self) -> WorkflowRun: def _handle_workflow_run_start(
max_sequence = ( self,
db.session.query(db.func.max(WorkflowRun.sequence_number)) *,
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id) session: Session,
.filter(WorkflowRun.app_id == self._workflow.app_id) workflow_id: str,
.scalar() user_id: str,
or 0 created_by_role: CreatedByRole,
) -> WorkflowRun:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.tenant_id == workflow.tenant_id,
WorkflowRun.app_id == workflow.app_id,
) )
max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1 new_sequence_number = max_sequence + 1
inputs = {**self._application_generate_entity.inputs} inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items(): for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation": if key.value == "conversation":
continue continue
inputs[f"sys.{key.value}"] = value inputs[f"sys.{key.value}"] = value
triggered_from = ( triggered_from = (
@ -96,33 +104,32 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run # init workflow run
with Session(db.engine, expire_on_commit=False) as session: workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
workflow_run = WorkflowRun()
system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
workflow_run.id = system_id or str(uuid4())
workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = (
CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
)
workflow_run.created_by = self._user.id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_run) workflow_run = WorkflowRun()
session.commit() workflow_run.id = workflow_run_id
workflow_run.tenant_id = workflow.tenant_id
workflow_run.app_id = workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = workflow.id
workflow_run.type = workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = workflow.version
workflow_run.graph = workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = created_by_role
workflow_run.created_by = user_id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_run)
return workflow_run return workflow_run
def _handle_workflow_run_success( def _handle_workflow_run_success(
self, self,
*,
session: Session,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
@ -141,7 +148,7 @@ class WorkflowCycleManage:
:param conversation_id: conversation id :param conversation_id: conversation id
:return: :return:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs) outputs = WorkflowEntry.handle_special_values(outputs)
@ -152,9 +159,6 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
@ -165,12 +169,12 @@ class WorkflowCycleManage:
) )
) )
db.session.close()
return workflow_run return workflow_run
def _handle_workflow_run_partial_success( def _handle_workflow_run_partial_success(
self, self,
*,
session: Session,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
@ -190,7 +194,7 @@ class WorkflowCycleManage:
:param conversation_id: conversation id :param conversation_id: conversation id
:return: :return:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
@ -201,8 +205,6 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -214,12 +216,12 @@ class WorkflowCycleManage:
) )
) )
db.session.close()
return workflow_run return workflow_run
def _handle_workflow_run_failed( def _handle_workflow_run_failed(
self, self,
*,
session: Session,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
@ -240,7 +242,7 @@ class WorkflowCycleManage:
:param error: error message :param error: error message
:return: :return:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id)
workflow_run.status = status.value workflow_run.status = status.value
workflow_run.error = error workflow_run.error = error
@ -249,21 +251,18 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
db.session.commit()
running_workflow_node_executions = ( stmt = select(WorkflowNodeExecution).where(
db.session.query(WorkflowNodeExecution) WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
.filter( WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
) )
running_workflow_node_executions = session.scalars(stmt).all()
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
@ -271,13 +270,6 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds() ).total_seconds()
db.session.commit()
db.session.close()
# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -485,14 +477,14 @@ class WorkflowCycleManage:
################################################# #################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow start to stream response. _ = session
:param task_id: task id
:param workflow_run: workflow run
:return:
"""
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -506,36 +498,32 @@ class WorkflowCycleManage:
) )
def _workflow_finish_to_stream_response( def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowFinishStreamResponse: ) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
:param workflow_run: workflow run
:return:
"""
# Attach WorkflowRun to an active session so "created_by_role" can be accessed.
workflow_run = db.session.merge(workflow_run)
# Refresh to ensure any expired attributes are fully loaded
db.session.refresh(workflow_run)
created_by = None created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
created_by_account = workflow_run.created_by_account stmt = select(Account).where(Account.id == workflow_run.created_by)
if created_by_account: account = session.scalar(stmt)
if account:
created_by = { created_by = {
"id": created_by_account.id, "id": account.id,
"name": created_by_account.name, "name": account.name,
"email": created_by_account.email, "email": account.email,
}
elif workflow_run.created_by_role == CreatedByRole.END_USER:
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
end_user = session.scalar(stmt)
if end_user:
created_by = {
"id": end_user.id,
"user": end_user.session_id,
} }
else: else:
created_by_end_user = workflow_run.created_by_end_user raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
if created_by_end_user:
created_by = {
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse( return WorkflowFinishStreamResponse(
task_id=task_id, task_id=task_id,
@ -895,14 +883,14 @@ class WorkflowCycleManage:
return None return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
""" """
Refetch workflow run Refetch workflow run
:param workflow_run_id: workflow run id :param workflow_run_id: workflow run id
:return: :return:
""" """
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run: if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id) raise WorkflowRunNotFoundError(workflow_run_id)

View File

@ -9,6 +9,8 @@ from typing import Any, Optional, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from flask import current_app from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import ( from core.ops.entities.config_entity import (
@ -329,15 +331,15 @@ class TraceTask:
): ):
self.trace_type = trace_type self.trace_type = trace_type
self.message_id = message_id self.message_id = message_id
self.workflow_run = workflow_run self.workflow_run_id = workflow_run.id if workflow_run else None
self.conversation_id = conversation_id self.conversation_id = conversation_id
self.user_id = user_id self.user_id = user_id
self.timer = timer self.timer = timer
self.kwargs = kwargs
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.app_id = None self.app_id = None
self.kwargs = kwargs
def execute(self): def execute(self):
return self.preprocess() return self.preprocess()
@ -345,19 +347,23 @@ class TraceTask:
preprocess_map = { preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
self.workflow_run, self.conversation_id, self.user_id workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
self.message_id, self.timer, **self.kwargs message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
self.conversation_id, self.timer, **self.kwargs conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
), ),
} }
@ -367,86 +373,100 @@ class TraceTask:
def conversation_trace(self, **kwargs): def conversation_trace(self, **kwargs):
return kwargs return kwargs
def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): def workflow_trace(
if not workflow_run: self,
raise ValueError("Workflow run not found") *,
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
):
if not workflow_run_id:
return {}
db.session.merge(workflow_run) with Session(db.engine) as session:
db.session.refresh(workflow_run) workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalars(workflow_run_stmt).first()
if not workflow_run:
raise ValueError("Workflow run not found")
workflow_id = workflow_run.workflow_id workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version workflow_run_version = workflow_run.version
error = workflow_run.error or "" error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or [] file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id # get workflow_app_log_id
workflow_app_log_data = ( workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
db.session.query(WorkflowAppLog) WorkflowAppLog.tenant_id == tenant_id,
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) WorkflowAppLog.app_id == workflow_run.app_id,
.first() WorkflowAppLog.workflow_run_id == workflow_run.id,
) )
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
# get message_id # get message_id
message_data = ( message_id = None
db.session.query(Message.id) if conversation_id:
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) message_data_stmt = select(Message.id).where(
.first() Message.conversation_id == conversation_id,
) Message.workflow_run_id == workflow_run_id,
message_id = str(message_data.id) if message_data else None )
message_id = session.scalar(message_data_stmt)
metadata = { metadata = {
"workflow_id": workflow_id, "workflow_id": workflow_id,
"conversation_id": conversation_id, "conversation_id": conversation_id,
"workflow_run_id": workflow_run_id, "workflow_run_id": workflow_run_id,
"tenant_id": tenant_id, "tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time, "elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status, "status": workflow_run_status,
"version": workflow_run_version, "version": workflow_run_version,
"total_tokens": total_tokens, "total_tokens": total_tokens,
"file_list": file_list, "file_list": file_list,
"triggered_form": workflow_run.triggered_from, "triggered_form": workflow_run.triggered_from,
"user_id": user_id, "user_id": user_id,
} }
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info return workflow_trace_info
def message_trace(self, message_id): def message_trace(self, message_id: str | None):
if not message_id:
return {}
message_data = get_message_data(message_id) message_data = get_message_data(message_id)
if not message_data: if not message_data:
return {} return {}
conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
return {}
conversation_mode = conversation_mode[0] conversation_mode = conversation_mode[0]
created_at = message_data.created_at created_at = message_data.created_at
inputs = message_data.message inputs = message_data.message

View File

@ -18,7 +18,7 @@ def filter_none_values(data: dict):
return new_data return new_data
def get_message_data(message_id): def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first() return db.session.query(Message).filter(Message.id == message_id).first()

View File

@ -3,6 +3,7 @@ import json
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID
@ -20,7 +21,7 @@ class Account(UserMixin, db.Model): # type: ignore[name-defined]
__tablename__ = "accounts" __tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=True) password = db.Column(db.String(255), nullable=True)

View File

@ -530,13 +530,13 @@ class Conversation(db.Model): # type: ignore[name-defined]
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
mode = db.Column(db.String(255), nullable=False) mode: Mapped[str] = mapped_column(db.String(255))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
summary = db.Column(db.Text) summary = db.Column(db.Text)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
@ -770,7 +770,7 @@ class Message(db.Model): # type: ignore[name-defined]
db.Index("message_created_at_idx", "created_at"), db.Index("message_created_at_idx", "created_at"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
@ -797,7 +797,7 @@ class Message(db.Model): # type: ignore[name-defined]
from_source = db.Column(db.String(255), nullable=False) from_source = db.Column(db.String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID) workflow_run_id = db.Column(StringUUID)
@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined]
external_user_id = db.Column(db.String(255), nullable=True) external_user_id = db.Column(db.String(255), nullable=True)
name = db.Column(db.String(255)) name = db.Column(db.String(255))
is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
session_id = db.Column(db.String(255), nullable=False) session_id: Mapped[str] = mapped_column()
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -392,40 +392,28 @@ class WorkflowRun(db.Model): # type: ignore[name-defined]
db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID)
sequence_number = db.Column(db.Integer, nullable=False) sequence_number: Mapped[int] = mapped_column()
workflow_id = db.Column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID)
type = db.Column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255))
triggered_from = db.Column(db.String(255), nullable=False) triggered_from: Mapped[str] = mapped_column(db.String(255))
version = db.Column(db.String(255), nullable=False) version: Mapped[str] = mapped_column(db.String(255))
graph = db.Column(db.Text) graph: Mapped[str] = mapped_column(db.Text)
inputs = db.Column(db.Text) inputs: Mapped[str] = mapped_column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text) error: Mapped[str] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime) finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0")) exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
def created_by_end_user(self):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property @property
def graph_dict(self): def graph_dict(self):
return json.loads(self.graph) if self.graph else {} return json.loads(self.graph) if self.graph else {}
@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model): # type: ignore[name-defined]
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False)
workflow_run_id = db.Column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False) created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)