diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 965c0c36ad..3ebc06e605 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -3,7 +3,7 @@ cd web && npm install
 
 echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 55fdbb8b50..03b15e7f27 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -37,7 +37,19 @@
         "FLASK_DEBUG": "1",
         "GEVENT_SUPPORT": "True"
       },
-      "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+      "args": [
+        "-A",
+        "app.celery",
+        "worker",
+        "-P",
+        "gevent",
+        "-c",
+        "1",
+        "--loglevel",
+        "info",
+        "-Q",
+        "dataset,generation,mail,ops_trace"
+      ]
     },
   ]
 }
\ No newline at end of file
diff --git a/api/README.md b/api/README.md
index 5f71dbe5f0..125cd8a78c 100644
--- a/api/README.md
+++ b/api/README.md
@@ -66,7 +66,7 @@
 10. If you need to debug local async processing, please start the worker service.
 
     ```bash
-    poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail
+    poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace
     ```
 
     The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
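Every worker launch path (devcontainer alias, VS Code launch config, README, and the Docker entrypoint further down) now lists the new `ops_trace` queue, because a Celery task bound to a named queue is only consumed by workers started with that queue in `-Q`. A minimal sketch of the binding, mirroring the `process_trace_tasks` task this patch adds (body elided here; the real one appears at the end of the diff):

```python
# A task declared with an explicit queue is routed to it by default;
# workers must be started with `-Q ...,ops_trace` to consume it.
from celery import shared_task

@shared_task(queue='ops_trace')
def process_trace_tasks(tasks_data: dict) -> None:
    ...  # real body lives in api/tasks/ops_trace_task.py (see below)
```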
diff --git a/api/app.py b/api/app.py
index 2c9b59706b..2ea7c6d235 100644
--- a/api/app.py
+++ b/api/app.py
@@ -26,7 +26,6 @@ from werkzeug.exceptions import Unauthorized
 from commands import register_commands
 
 # DO NOT REMOVE BELOW
-from events import event_handlers
 from extensions import (
     ext_celery,
     ext_code_based_extension,
@@ -43,7 +42,6 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
-from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService
 
 # DO NOT REMOVE ABOVE
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py
index 3482d5c5cf..c5dd88fb24 100644
--- a/api/core/moderation/input_moderation.py
+++ b/api/core/moderation/input_moderation.py
@@ -57,7 +57,7 @@ class InputModeration:
                 timer=timer
             )
         )
-
+
         if not moderation_result.flagged:
             return False, inputs, query
 
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py
index b615f21e6c..db7e0806ee 100644
--- a/api/core/ops/entities/trace_entity.py
+++ b/api/core/ops/entities/trace_entity.py
@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):
 
 
 class GenerateNameTraceInfo(BaseTraceInfo):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     tenant_id: str
+
+
+trace_info_info_map = {
+    'WorkflowTraceInfo': WorkflowTraceInfo,
+    'MessageTraceInfo': MessageTraceInfo,
+    'ModerationTraceInfo': ModerationTraceInfo,
+    'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
+    'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
+    'ToolTraceInfo': ToolTraceInfo,
+    'GenerateNameTraceInfo': GenerateNameTraceInfo,
+}
\ No newline at end of file
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 05d34c5527..46795c8c3c 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         # add span
         if trace_info.message_id:
             span_data = LangfuseSpan(
+                id=node_execution_id,
                 name=f"{node_name}_{node_execution_id}",
                 input=inputs,
                 output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
         else:
             span_data = LangfuseSpan(
+                id=node_execution_id,
                 name=f"{node_name}_{node_execution_id}",
                 input=inputs,
                 output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):
 
         self.add_span(langfuse_span_data=span_data)
 
+        process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+        if process_data and process_data.get("model_mode") == "chat":
+            total_token = metadata.get("total_tokens", 0)
+            # add generation
+            generation_usage = GenerationUsage(
+                totalTokens=total_token,
+            )
+
+            node_generation_data = LangfuseGeneration(
+                name=f"generation_{node_execution_id}",
+                trace_id=trace_id,
+                parent_observation_id=node_execution_id,
+                start_time=created_at,
+                end_time=finished_at,
+                input=inputs,
+                output=outputs,
+                metadata=metadata,
+                level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                status_message=trace_info.error if trace_info.error else "",
+                usage=generation_usage,
+            )
+
+            self.add_generation(langfuse_generation_data=node_generation_data)
+
     def message_trace(
         self, trace_info: MessageTraceInfo, **kwargs
     ):
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         if message_data.from_end_user_id:
             end_user_data: EndUser = db.session.query(EndUser).filter(
                 EndUser.id == message_data.from_end_user_id
-            ).first().session_id
+            ).first()
             user_id = end_user_data.session_id
 
         trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
             unit=UnitEnum.TOKENS,
+            totalCost=message_data.total_price,
         )
 
         langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
+            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
             status_message=trace_info.error,
         )
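The `trace_info_info_map` added in trace_entity.py above exists so the Celery worker can rebuild the right Pydantic model from a class name that crosses the broker as a plain string. A self-contained sketch of that dispatch using one stand-in model (the real map covers all seven trace-info classes; the payload values are illustrative):

```python
from typing import Optional
from pydantic import BaseModel

class GenerateNameTraceInfo(BaseModel):
    conversation_id: Optional[str] = None  # now optional, per the change above
    tenant_id: str

trace_info_info_map = {'GenerateNameTraceInfo': GenerateNameTraceInfo}

# What the worker receives: a type name plus a plain dict payload.
payload = {
    'trace_info_type': 'GenerateNameTraceInfo',
    'trace_info': {'tenant_id': 't-123'},
}
cls = trace_info_info_map[payload['trace_info_type']]
trace_info = cls(**payload['trace_info'])  # conversation_id falls back to None
```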
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 00750ab81f..2ce12f28d1 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -1,16 +1,17 @@
 import json
+import logging
 import os
 import queue
 import threading
+import time
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID
 
-from flask import Flask, current_app
+from flask import current_app
 
 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
-from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import (
     LangfuseConfig,
     LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks
 
 provider_config_map = {
     TracingProviderEnum.LANGFUSE.value: {
@@ -105,7 +107,7 @@ class OpsTraceManager:
         return config_class(**new_config).model_dump()
 
     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
-    def execute(self, trace_instance: BaseTraceInstance):
+    def execute(self):
         method_name, trace_info = self.preprocess()
-        if trace_instance:
-            method = trace_instance.trace
-            method(trace_info)
+        return trace_info
 
     def preprocess(self):
         if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
@@ -372,7 +372,7 @@ class TraceTask:
         }
 
         workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run,
+            workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
             tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
         message_tokens = message_data.message_tokens
 
         message_trace_info = MessageTraceInfo(
-            message_data=message_data,
+            message_id=message_id,
+            message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
@@ -469,7 +470,7 @@ class TraceTask:
         moderation_trace_info = ModerationTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
             inputs=inputs,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
             action=moderation_result.action,
             preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:
 
         suggested_question_trace_info = SuggestedQuestionTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
             start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query if message_data.query else message_data.inputs,
-            documents=documents,
+            documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
         )
 
         return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:
 
         tool_trace_info = ToolTraceInfo(
             message_id=message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             tool_name=tool_name,
             start_time=timer.get("start") if timer else created_time,
             end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
 
         return generate_name_trace_info
 
 
+trace_manager_timer = None
+trace_manager_queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 1))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
 class TraceQueueManager:
     def __init__(self, app_id=None, conversation_id=None, message_id=None):
-        tracing_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
-        self.queue = queue.Queue()
-        self.is_running = True
-        self.thread = threading.Thread(
-            target=self.process_queue, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'trace_instance': tracing_instance
-            }
-        )
-        self.thread.start()
+        global trace_manager_timer
 
-    def stop(self):
-        self.is_running = False
-
-    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
-        with flask_app.app_context():
-            while self.is_running:
-                try:
-                    task = self.queue.get(timeout=60)
-                    task.execute(trace_instance)
-                    self.queue.task_done()
-                except queue.Empty:
-                    self.stop()
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.message_id = message_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+        self.flask_app = current_app._get_current_object()
+        if trace_manager_timer is None:
+            self.start_timer()
 
     def add_trace_task(self, trace_task: TraceTask):
-        self.queue.put(trace_task)
+        global trace_manager_timer
+        global trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.debug(f"Error adding trace task: {e}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.debug(f"Error processing trace tasks: {e}")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(
+                trace_manager_interval, self.run
+            )
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                trace_info = task.execute()
+                task_data = {
+                    "app_id": self.app_id,
+                    "conversation_id": self.conversation_id,
+                    "message_id": self.message_id,
+                    "trace_info_type": type(trace_info).__name__,
+                    "trace_info": trace_info.model_dump() if trace_info else {},
+                }
+                process_trace_tasks.delay(task_data)
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 8544d7c3c8..ea2a194a68 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -12,7 +12,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceTask, TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
@@ -357,7 +357,7 @@ class DatasetRetrieval:
             db.session.commit()
 
         # get tracing instance
-        trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 386fa410aa..ea0cdf96e7 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -94,7 +94,7 @@ class ParameterExtractorNode(LLMNode):
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
-            and node_data.reasoning_mode == 'function_call':
+                and node_data.reasoning_mode == 'function_call':
             # use function call
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
                 node_data, query, variable_pool, model_config, memory
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 7b1f1dfc03..0bb494abd7 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -9,7 +9,7 @@ fi
 
 if [[ "${MODE}" == "worker" ]]; then
   celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \
-    -Q ${CELERY_QUEUES:-dataset,generation,mail}
+    -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace}
 elif [[ "${MODE}" == "beat" ]]; then
   celery -A app.celery beat --loglevel INFO
 else
diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
index a322b9f502..be2c615525 100644
--- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
+++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
@@ -31,17 +31,11 @@ def upgrade():
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False)
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.Text(), nullable=True))
-
     # ### end Alembic commands ###
 
 
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
+    # ### commands auto generated by Alembic - please adjust! ##
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
index 20d9c5d1fb..1ac44d083a 100644
--- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
+++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py
@@ -35,18 +35,11 @@ def upgrade():
 
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
-
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.TEXT(), autoincrement=False, nullable=True))
-
     op.create_table('tracing_app_configs',
         sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
         sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False),
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 757a5bf8de..672c2be8fa 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -352,6 +352,101 @@ class Document(db.Model):
         return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
             .filter(DocumentSegment.document_id == self.id).scalar()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'dataset_id': self.dataset_id,
+            'position': self.position,
+            'data_source_type': self.data_source_type,
+            'data_source_info': self.data_source_info,
+            'dataset_process_rule_id': self.dataset_process_rule_id,
+            'batch': self.batch,
+            'name': self.name,
+            'created_from': self.created_from,
+            'created_by': self.created_by,
+            'created_api_request_id': self.created_api_request_id,
+            'created_at': self.created_at,
+            'processing_started_at': self.processing_started_at,
+            'file_id': self.file_id,
+            'word_count': self.word_count,
+            'parsing_completed_at': self.parsing_completed_at,
+            'cleaning_completed_at': self.cleaning_completed_at,
+            'splitting_completed_at': self.splitting_completed_at,
+            'tokens': self.tokens,
+            'indexing_latency': self.indexing_latency,
+            'completed_at': self.completed_at,
+            'is_paused': self.is_paused,
+            'paused_by': self.paused_by,
+            'paused_at': self.paused_at,
+            'error': self.error,
+            'stopped_at': self.stopped_at,
+            'indexing_status': self.indexing_status,
+            'enabled': self.enabled,
+            'disabled_at': self.disabled_at,
+            'disabled_by': self.disabled_by,
+            'archived': self.archived,
+            'archived_reason': self.archived_reason,
+            'archived_by': self.archived_by,
+            'archived_at': self.archived_at,
+            'updated_at': self.updated_at,
+            'doc_type': self.doc_type,
+            'doc_metadata': self.doc_metadata,
+            'doc_form': self.doc_form,
+            'doc_language': self.doc_language,
+            'display_status': self.display_status,
+            'data_source_info_dict': self.data_source_info_dict,
+            'average_segment_length': self.average_segment_length,
+            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
+            'dataset': self.dataset.to_dict() if self.dataset else None,
+            'segment_count': self.segment_count,
+            'hit_count': self.hit_count
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            dataset_id=data.get('dataset_id'),
+            position=data.get('position'),
+            data_source_type=data.get('data_source_type'),
+            data_source_info=data.get('data_source_info'),
+            dataset_process_rule_id=data.get('dataset_process_rule_id'),
+            batch=data.get('batch'),
+            name=data.get('name'),
+            created_from=data.get('created_from'),
+            created_by=data.get('created_by'),
+            created_api_request_id=data.get('created_api_request_id'),
+            created_at=data.get('created_at'),
+            processing_started_at=data.get('processing_started_at'),
+            file_id=data.get('file_id'),
+            word_count=data.get('word_count'),
+            parsing_completed_at=data.get('parsing_completed_at'),
+            cleaning_completed_at=data.get('cleaning_completed_at'),
+            splitting_completed_at=data.get('splitting_completed_at'),
+            tokens=data.get('tokens'),
+            indexing_latency=data.get('indexing_latency'),
+            completed_at=data.get('completed_at'),
+            is_paused=data.get('is_paused'),
+            paused_by=data.get('paused_by'),
+            paused_at=data.get('paused_at'),
+            error=data.get('error'),
+            stopped_at=data.get('stopped_at'),
+            indexing_status=data.get('indexing_status'),
+            enabled=data.get('enabled'),
+            disabled_at=data.get('disabled_at'),
+            disabled_by=data.get('disabled_by'),
+            archived=data.get('archived'),
+            archived_reason=data.get('archived_reason'),
+            archived_by=data.get('archived_by'),
+            archived_at=data.get('archived_at'),
+            updated_at=data.get('updated_at'),
+            doc_type=data.get('doc_type'),
+            doc_metadata=data.get('doc_metadata'),
+            doc_form=data.get('doc_form'),
+            doc_language=data.get('doc_language')
+        )
+
 
 class DocumentSegment(db.Model):
     __tablename__ = 'document_segments'
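These `to_dict()`/`from_dict()` pairs (here on `Document`, and next on `Message` and `WorkflowRun`) exist because Celery's default JSON serializer cannot ship live SQLAlchemy objects across the broker: the producer flattens the model before `.delay()`, and the worker rebuilds a detached instance. A stand-in class sketching that contract (the real methods are the ones in the diff; note the `Message` pair emits datetimes as ISO strings on the way out):

```python
import json
from datetime import datetime, timezone

class MessageStandIn:
    """Plain stand-in for the SQLAlchemy models gaining to_dict()/from_dict()."""
    def __init__(self, id: str, answer: str, created_at: datetime):
        self.id, self.answer, self.created_at = id, answer, created_at

    def to_dict(self) -> dict:
        # ISO-format datetimes so the dict survives json.dumps unchanged.
        return {'id': self.id, 'answer': self.answer,
                'created_at': self.created_at.isoformat()}

    @classmethod
    def from_dict(cls, data: dict) -> 'MessageStandIn':
        return cls(id=data['id'], answer=data['answer'],
                   created_at=datetime.fromisoformat(data['created_at']))

msg = MessageStandIn('m-1', 'hi', datetime.now(timezone.utc))
wire = json.dumps(msg.to_dict())                 # what crosses the broker
restored = MessageStandIn.from_dict(json.loads(wire))
assert restored.answer == msg.answer
```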
diff --git a/api/models/model.py b/api/models/model.py
index ecb89861db..07d7f6d891 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -838,6 +838,49 @@ class Message(db.Model):
 
         return None
 
+    def to_dict(self) -> dict:
+        return {
+            'id': self.id,
+            'app_id': self.app_id,
+            'conversation_id': self.conversation_id,
+            'inputs': self.inputs,
+            'query': self.query,
+            'message': self.message,
+            'answer': self.answer,
+            'status': self.status,
+            'error': self.error,
+            'message_metadata': self.message_metadata_dict,
+            'from_source': self.from_source,
+            'from_end_user_id': self.from_end_user_id,
+            'from_account_id': self.from_account_id,
+            'created_at': self.created_at.isoformat(),
+            'updated_at': self.updated_at.isoformat(),
+            'agent_based': self.agent_based,
+            'workflow_run_id': self.workflow_run_id
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data['id'],
+            app_id=data['app_id'],
+            conversation_id=data['conversation_id'],
+            inputs=data['inputs'],
+            query=data['query'],
+            message=data['message'],
+            answer=data['answer'],
+            status=data['status'],
+            error=data['error'],
+            message_metadata=json.dumps(data['message_metadata']),
+            from_source=data['from_source'],
+            from_end_user_id=data['from_end_user_id'],
+            from_account_id=data['from_account_id'],
+            created_at=data['created_at'],
+            updated_at=data['updated_at'],
+            agent_based=data['agent_based'],
+            workflow_run_id=data['workflow_run_id']
+        )
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
diff --git a/api/models/workflow.py b/api/models/workflow.py
index d9bc784878..2d6491032b 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -324,6 +324,55 @@ class WorkflowRun(db.Model):
     def workflow(self):
         return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'app_id': self.app_id,
+            'sequence_number': self.sequence_number,
+            'workflow_id': self.workflow_id,
+            'type': self.type,
+            'triggered_from': self.triggered_from,
+            'version': self.version,
+            'graph': self.graph_dict,
+            'inputs': self.inputs_dict,
+            'status': self.status,
+            'outputs': self.outputs_dict,
+            'error': self.error,
+            'elapsed_time': self.elapsed_time,
+            'total_tokens': self.total_tokens,
+            'total_steps': self.total_steps,
+            'created_by_role': self.created_by_role,
+            'created_by': self.created_by,
+            'created_at': self.created_at,
+            'finished_at': self.finished_at,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'WorkflowRun':
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            app_id=data.get('app_id'),
+            sequence_number=data.get('sequence_number'),
+            workflow_id=data.get('workflow_id'),
+            type=data.get('type'),
+            triggered_from=data.get('triggered_from'),
+            version=data.get('version'),
+            graph=json.dumps(data.get('graph')),
+            inputs=json.dumps(data.get('inputs')),
+            status=data.get('status'),
+            outputs=json.dumps(data.get('outputs')),
+            error=data.get('error'),
+            elapsed_time=data.get('elapsed_time'),
+            total_tokens=data.get('total_tokens'),
+            total_steps=data.get('total_steps'),
+            created_by_role=data.get('created_by_role'),
+            created_by=data.get('created_by'),
+            created_at=data.get('created_at'),
+            finished_at=data.get('finished_at'),
+        )
+
 
 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
new file mode 100644
index 0000000000..1d33609205
--- /dev/null
+++ b/api/tasks/ops_trace_task.py
@@ -0,0 +1,46 @@
+import logging
+import time
+
+from celery import shared_task
+from flask import current_app
+
+from core.ops.entities.trace_entity import trace_info_info_map
+from core.rag.models.document import Document
+from models.model import Message
+from models.workflow import WorkflowRun
+
+
+@shared_task(queue='ops_trace')
+def process_trace_tasks(tasks_data):
+    """
+    Async process trace tasks
+    :param tasks_data: List of dictionaries containing task data
+
+    Usage: process_trace_tasks.delay(tasks_data)
+    """
+    from core.ops.ops_trace_manager import OpsTraceManager
+
+    trace_info = tasks_data.get('trace_info')
+    app_id = tasks_data.get('app_id')
+    conversation_id = tasks_data.get('conversation_id')
+    message_id = tasks_data.get('message_id')
+    trace_info_type = tasks_data.get('trace_info_type')
+    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+
+    if trace_info.get('message_data'):
+        trace_info['message_data'] = Message.from_dict(data=trace_info['message_data'])
+    if trace_info.get('workflow_data'):
+        trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data'])
+    if trace_info.get('documents'):
+        trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']]
+
+    try:
+        if trace_instance:
+            with current_app.app_context():
+                trace_type = trace_info_info_map.get(trace_info_type)
+                if trace_type:
+                    trace_info = trace_type(**trace_info)
+                trace_instance.trace(trace_info)
+                end_at = time.perf_counter()
+    except Exception:
+        logging.exception("Processing trace tasks failed")
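For reference, this is how the producer side (`TraceQueueManager.send_to_celery` above) hands work to this task; the payload shape mirrors the dict built in ops_trace_manager.py, with illustrative values, and assumes a running Dify API app context:

```python
from tasks.ops_trace_task import process_trace_tasks

task_data = {
    "app_id": "app-uuid",                     # illustrative IDs
    "conversation_id": "conv-uuid",
    "message_id": "msg-uuid",
    "trace_info_type": "MessageTraceInfo",    # key into trace_info_info_map
    "trace_info": {},                         # trace_info.model_dump() in the real caller
}
process_trace_tasks.delay(task_data)          # routed to the 'ops_trace' queue
```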