mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 04:25:59 +08:00
refactor: Remove RepositoryFactory (#19176)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
a6827493f0
commit
f23cf98317
@ -54,7 +54,6 @@ def initialize_extensions(app: DifyApp):
|
|||||||
ext_otel,
|
ext_otel,
|
||||||
ext_proxy_fix,
|
ext_proxy_fix,
|
||||||
ext_redis,
|
ext_redis,
|
||||||
ext_repositories,
|
|
||||||
ext_sentry,
|
ext_sentry,
|
||||||
ext_set_secretkey,
|
ext_set_secretkey,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
@ -75,7 +74,6 @@ def initialize_extensions(app: DifyApp):
|
|||||||
ext_migrate,
|
ext_migrate,
|
||||||
ext_redis,
|
ext_redis,
|
||||||
ext_storage,
|
ext_storage,
|
||||||
ext_repositories,
|
|
||||||
ext_celery,
|
ext_celery,
|
||||||
ext_login,
|
ext_login,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
|
@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
|
|||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||||
from core.workflow.repository import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
@ -163,12 +163,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -231,12 +229,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -297,12 +293,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
|
@ -9,7 +9,6 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
@ -58,7 +57,7 @@ from core.app.entities.task_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
@ -66,6 +65,7 @@ from core.workflow.enums import SystemVariableKey
|
|||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Conversation, EndUser, Message, MessageFile
|
from models import Conversation, EndUser, Message, MessageFile
|
||||||
@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||||
|
|
||||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
workflow_system_variables={
|
workflow_system_variables={
|
||||||
SystemVariableKey.QUERY: message.query,
|
SystemVariableKey.QUERY: message.query,
|
||||||
|
@ -18,13 +18,13 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
|||||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.repository import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models import Account, App, EndUser, Workflow
|
from models import Account, App, EndUser, Workflow
|
||||||
@ -138,12 +138,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -264,12 +262,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
@ -329,12 +325,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
|
|
||||||
# Create workflow node execution repository
|
# Create workflow node execution repository
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory,
|
||||||
"tenant_id": application_generate_entity.app_config.tenant_id,
|
tenant_id=application_generate_entity.app_config.tenant_id,
|
||||||
"app_id": application_generate_entity.app_config.app_id,
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
"session_factory": session_factory,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._generate(
|
return self._generate(
|
||||||
|
@ -9,7 +9,6 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AgentChatAppGenerateEntity,
|
AgentChatAppGenerateEntity,
|
||||||
@ -45,6 +44,7 @@ from core.app.entities.task_entities import (
|
|||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||||
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
1
api/core/base/__init__.py
Normal file
1
api/core/base/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Core base package
|
6
api/core/base/tts/__init__.py
Normal file
6
api/core/base/tts/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppGeneratorTTSPublisher",
|
||||||
|
"AudioTrunk",
|
||||||
|
]
|
@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
|||||||
UnitEnum,
|
UnitEnum,
|
||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values
|
from core.ops.utils import filter_none_values
|
||||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
|
||||||
@ -113,8 +113,8 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
|
session_factory=session_factory, tenant_id=trace_info.tenant_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
|
@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||||||
LangSmithRunUpdateModel,
|
LangSmithRunUpdateModel,
|
||||||
)
|
)
|
||||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models.model import EndUser, MessageFile
|
||||||
|
|
||||||
@ -137,12 +137,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
||||||
"tenant_id": trace_info.tenant_id,
|
|
||||||
"app_id": trace_info.metadata.get("app_id"),
|
|
||||||
"session_factory": session_factory,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
|
@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
|
|||||||
TraceTaskName,
|
TraceTaskName,
|
||||||
WorkflowTraceInfo,
|
WorkflowTraceInfo,
|
||||||
)
|
)
|
||||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import EndUser, MessageFile
|
from models.model import EndUser, MessageFile
|
||||||
|
|
||||||
@ -150,12 +150,8 @@ class OpikDataTrace(BaseTraceInstance):
|
|||||||
|
|
||||||
# through workflow_run_id get all_nodes_execution using repository
|
# through workflow_run_id get all_nodes_execution using repository
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
|
||||||
"tenant_id": trace_info.tenant_id,
|
|
||||||
"app_id": trace_info.metadata.get("app_id"),
|
|
||||||
"session_factory": session_factory,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all executions for this workflow run
|
# Get all executions for this workflow run
|
||||||
|
@ -4,3 +4,9 @@ Repository implementations for data access.
|
|||||||
This package contains concrete implementations of the repository interfaces
|
This package contains concrete implementations of the repository interfaces
|
||||||
defined in the core.workflow.repository package.
|
defined in the core.workflow.repository package.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||||
|
]
|
||||||
|
@ -1,87 +0,0 @@
|
|||||||
"""
|
|
||||||
Registry for repository implementations.
|
|
||||||
|
|
||||||
This module is responsible for registering factory functions with the repository factory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
|
||||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Storage type constants
|
|
||||||
STORAGE_TYPE_RDBMS = "rdbms"
|
|
||||||
STORAGE_TYPE_HYBRID = "hybrid"
|
|
||||||
|
|
||||||
|
|
||||||
def register_repositories() -> None:
|
|
||||||
"""
|
|
||||||
Register repository factory functions with the RepositoryFactory.
|
|
||||||
|
|
||||||
This function reads configuration settings to determine which repository
|
|
||||||
implementations to register.
|
|
||||||
"""
|
|
||||||
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
|
||||||
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
|
||||||
|
|
||||||
# Check storage type and register appropriate implementation
|
|
||||||
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
|
||||||
# Register SQLAlchemy implementation for RDBMS storage
|
|
||||||
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
|
||||||
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
|
||||||
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
|
||||||
# Hybrid storage is not yet implemented
|
|
||||||
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
|
||||||
else:
|
|
||||||
# Unknown storage type
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
|
||||||
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
|
||||||
"""
|
|
||||||
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
|
||||||
|
|
||||||
This factory function creates a repository for the RDBMS storage type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params: Parameters for creating the repository, including:
|
|
||||||
- tenant_id: Required. The tenant ID for multi-tenancy.
|
|
||||||
- app_id: Optional. The application ID for filtering.
|
|
||||||
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
|
||||||
a new sessionmaker will be created using the global database engine.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A WorkflowNodeExecutionRepository instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If required parameters are missing
|
|
||||||
"""
|
|
||||||
# Extract required parameters
|
|
||||||
tenant_id = params.get("tenant_id")
|
|
||||||
if tenant_id is None:
|
|
||||||
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
|
||||||
|
|
||||||
# Extract optional parameters
|
|
||||||
app_id = params.get("app_id")
|
|
||||||
|
|
||||||
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
|
||||||
session_factory = params.get("session_factory")
|
|
||||||
if session_factory is None:
|
|
||||||
# Create a sessionmaker using the same engine as the global db session
|
|
||||||
session_factory = sessionmaker(bind=db.engine)
|
|
||||||
|
|
||||||
# Create and return the repository
|
|
||||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
|
||||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
|
||||||
)
|
|
@ -10,13 +10,13 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
|||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SQLAlchemyWorkflowNodeExecutionRepository:
|
class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||||
"""
|
"""
|
||||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||||
|
|
@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
WorkflowNodeExecution repository implementations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
|
||||||
]
|
|
@ -6,10 +6,9 @@ for accessing and manipulating data, regardless of the underlying
|
|||||||
storage mechanism.
|
storage mechanism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RepositoryFactory",
|
"OrderConfig",
|
||||||
"WorkflowNodeExecutionRepository",
|
"WorkflowNodeExecutionRepository",
|
||||||
]
|
]
|
||||||
|
@ -1,97 +0,0 @@
|
|||||||
"""
|
|
||||||
Repository factory for creating repository instances.
|
|
||||||
|
|
||||||
This module provides a simple factory interface for creating repository instances.
|
|
||||||
It does not contain any implementation details or dependencies on specific repositories.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable, Mapping
|
|
||||||
from typing import Any, Literal, Optional, cast
|
|
||||||
|
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
|
||||||
|
|
||||||
# Type for factory functions - takes a dict of parameters and returns any repository type
|
|
||||||
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
|
||||||
|
|
||||||
# Type for workflow node execution factory function
|
|
||||||
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
|
||||||
|
|
||||||
# Repository type literals
|
|
||||||
_RepositoryType = Literal["workflow_node_execution"]
|
|
||||||
|
|
||||||
|
|
||||||
class RepositoryFactory:
|
|
||||||
"""
|
|
||||||
Factory class for creating repository instances.
|
|
||||||
|
|
||||||
This factory delegates the actual repository creation to implementation-specific
|
|
||||||
factory functions that are registered with the factory at runtime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Dictionary to store factory functions
|
|
||||||
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
|
||||||
"""
|
|
||||||
Register a factory function for a specific repository type.
|
|
||||||
This is a private method and should not be called directly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
|
||||||
factory_func: A function that takes parameters and returns a repository instance
|
|
||||||
"""
|
|
||||||
cls._factory_functions[repository_type] = factory_func
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
|
||||||
"""
|
|
||||||
Create a new repository instance with the provided parameters.
|
|
||||||
This is a private method and should not be called directly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repository_type: The type of repository to create
|
|
||||||
params: A dictionary of parameters to pass to the factory function
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new instance of the requested repository
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no factory function is registered for the repository type
|
|
||||||
"""
|
|
||||||
if repository_type not in cls._factory_functions:
|
|
||||||
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
|
||||||
|
|
||||||
# Use empty dict if params is None
|
|
||||||
params = params or {}
|
|
||||||
|
|
||||||
return cls._factory_functions[repository_type](params)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
|
||||||
"""
|
|
||||||
Register a factory function for the workflow node execution repository.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
|
||||||
"""
|
|
||||||
cls._register_factory("workflow_node_execution", factory_func)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_workflow_node_execution_repository(
|
|
||||||
cls, params: Optional[Mapping[str, Any]] = None
|
|
||||||
) -> WorkflowNodeExecutionRepository:
|
|
||||||
"""
|
|
||||||
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
params: A dictionary of parameters to pass to the factory function
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new instance of the WorkflowNodeExecutionRepository
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
|
||||||
"""
|
|
||||||
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
|
||||||
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
|
@ -6,7 +6,6 @@ from typing import Optional, Union
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
@ -52,10 +51,11 @@ from core.app.entities.task_entities import (
|
|||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
@ -102,7 +102,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid user type: {type(user)}")
|
raise ValueError(f"Invalid user type: {type(user)}")
|
||||||
|
|
||||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
workflow_system_variables={
|
workflow_system_variables={
|
||||||
SystemVariableKey.FILES: application_generate_entity.files,
|
SystemVariableKey.FILES: application_generate_entity.files,
|
@ -69,7 +69,7 @@ from models.workflow import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowCycleManage:
|
class WorkflowCycleManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
@ -1,18 +0,0 @@
|
|||||||
"""
|
|
||||||
Extension for initializing repositories.
|
|
||||||
|
|
||||||
This extension registers repository implementations with the RepositoryFactory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from core.repositories.repository_registry import register_repositories
|
|
||||||
from dify_app import DifyApp
|
|
||||||
|
|
||||||
|
|
||||||
def init_app(_app: DifyApp) -> None:
|
|
||||||
"""
|
|
||||||
Initialize repository implementations.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_app: The Flask application instance (unused)
|
|
||||||
"""
|
|
||||||
register_repositories()
|
|
@ -2,7 +2,7 @@ import threading
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.workflow.repository import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
@ -129,12 +129,8 @@ class WorkflowRunService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# Use the repository to get the node executions
|
# Use the repository to get the node executions
|
||||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
|
||||||
"tenant_id": app_model.tenant_id,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"session_factory": db.session.get_bind(),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the repository to get the node executions with ordering
|
# Use the repository to get the node executions with ordering
|
||||||
|
@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
|||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.variables import Variable
|
from core.variables import Variable
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
@ -21,7 +22,6 @@ from core.workflow.nodes.enums import ErrorStrategy
|
|||||||
from core.workflow.nodes.event import RunCompletedEvent
|
from core.workflow.nodes.event import RunCompletedEvent
|
||||||
from core.workflow.nodes.event.types import NodeEvent
|
from core.workflow.nodes.event.types import NodeEvent
|
||||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.repository import RepositoryFactory
|
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -285,12 +285,8 @@ class WorkflowService:
|
|||||||
workflow_node_execution.workflow_id = draft_workflow.id
|
workflow_node_execution.workflow_id = draft_workflow.id
|
||||||
|
|
||||||
# Use the repository to save the workflow node execution
|
# Use the repository to save the workflow node execution
|
||||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
|
||||||
"tenant_id": app_model.tenant_id,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"session_factory": db.session.get_bind(),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
repository.save(workflow_node_execution)
|
repository.save(workflow_node_execution)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from celery import shared_task # type: ignore
|
|||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
from core.workflow.repository import RepositoryFactory
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import AppDatasetJoin
|
from models.dataset import AppDatasetJoin
|
||||||
from models.model import (
|
from models.model import (
|
||||||
@ -189,12 +189,8 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
|
|||||||
|
|
||||||
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||||
# Create a repository instance for WorkflowNodeExecution
|
# Create a repository instance for WorkflowNodeExecution
|
||||||
repository = RepositoryFactory.create_workflow_node_execution_repository(
|
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
params={
|
session_factory=db.engine, tenant_id=tenant_id, app_id=app_id
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"app_id": app_id,
|
|
||||||
"session_factory": db.session.get_bind(),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the clear method to delete all records for this tenant_id and app_id
|
# Use the clear method to delete all records for this tenant_id and app_id
|
||||||
|
@ -0,0 +1,348 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueNodeFailedEvent,
|
||||||
|
QueueNodeStartedEvent,
|
||||||
|
QueueNodeSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||||
|
from models.enums import CreatedByRole
|
||||||
|
from models.workflow import (
|
||||||
|
Workflow,
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
WorkflowRun,
|
||||||
|
WorkflowRunStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_app_generate_entity():
|
||||||
|
entity = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||||
|
entity.inputs = {"query": "test query"}
|
||||||
|
entity.invoke_from = InvokeFrom.WEB_APP
|
||||||
|
# Create app_config as a separate mock
|
||||||
|
app_config = MagicMock()
|
||||||
|
app_config.tenant_id = "test-tenant-id"
|
||||||
|
app_config.app_id = "test-app-id"
|
||||||
|
entity.app_config = app_config
|
||||||
|
return entity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_workflow_system_variables():
|
||||||
|
return {
|
||||||
|
SystemVariableKey.QUERY: "test query",
|
||||||
|
SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
|
||||||
|
SystemVariableKey.USER_ID: "test-user-id",
|
||||||
|
SystemVariableKey.APP_ID: "test-app-id",
|
||||||
|
SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
|
||||||
|
SystemVariableKey.WORKFLOW_RUN_ID: "test-workflow-run-id",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_node_execution_repository():
|
||||||
|
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||||
|
repo.get_by_node_execution_id.return_value = None
|
||||||
|
repo.get_running_executions.return_value = []
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository):
|
||||||
|
return WorkflowCycleManager(
|
||||||
|
application_generate_entity=mock_app_generate_entity,
|
||||||
|
workflow_system_variables=mock_workflow_system_variables,
|
||||||
|
workflow_node_execution_repository=mock_node_execution_repository,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session():
|
||||||
|
session = MagicMock(spec=Session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_workflow():
|
||||||
|
workflow = MagicMock(spec=Workflow)
|
||||||
|
workflow.id = "test-workflow-id"
|
||||||
|
workflow.tenant_id = "test-tenant-id"
|
||||||
|
workflow.app_id = "test-app-id"
|
||||||
|
workflow.type = "chat"
|
||||||
|
workflow.version = "1.0"
|
||||||
|
workflow.graph = json.dumps({"nodes": [], "edges": []})
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_workflow_run():
|
||||||
|
workflow_run = MagicMock(spec=WorkflowRun)
|
||||||
|
workflow_run.id = "test-workflow-run-id"
|
||||||
|
workflow_run.tenant_id = "test-tenant-id"
|
||||||
|
workflow_run.app_id = "test-app-id"
|
||||||
|
workflow_run.workflow_id = "test-workflow-id"
|
||||||
|
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||||
|
workflow_run.created_by_role = CreatedByRole.ACCOUNT
|
||||||
|
workflow_run.created_by = "test-user-id"
|
||||||
|
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
workflow_run.inputs_dict = {"query": "test query"}
|
||||||
|
workflow_run.outputs_dict = {"answer": "test answer"}
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
|
|
||||||
|
def test_init(
|
||||||
|
workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository
|
||||||
|
):
|
||||||
|
"""Test initialization of WorkflowCycleManager"""
|
||||||
|
assert workflow_cycle_manager._workflow_run is None
|
||||||
|
assert workflow_cycle_manager._workflow_node_executions == {}
|
||||||
|
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
|
||||||
|
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
|
||||||
|
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow):
|
||||||
|
"""Test _handle_workflow_run_start method"""
|
||||||
|
# Mock session.scalar to return the workflow and max sequence
|
||||||
|
mock_session.scalar.side_effect = [mock_workflow, 5]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
workflow_run = workflow_cycle_manager._handle_workflow_run_start(
|
||||||
|
session=mock_session,
|
||||||
|
workflow_id="test-workflow-id",
|
||||||
|
user_id="test-user-id",
|
||||||
|
created_by_role=CreatedByRole.ACCOUNT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert workflow_run.tenant_id == mock_workflow.tenant_id
|
||||||
|
assert workflow_run.app_id == mock_workflow.app_id
|
||||||
|
assert workflow_run.workflow_id == mock_workflow.id
|
||||||
|
assert workflow_run.sequence_number == 6 # max_sequence + 1
|
||||||
|
assert workflow_run.status == WorkflowRunStatus.RUNNING
|
||||||
|
assert workflow_run.created_by_role == CreatedByRole.ACCOUNT
|
||||||
|
assert workflow_run.created_by == "test-user-id"
|
||||||
|
|
||||||
|
# Verify session.add was called
|
||||||
|
mock_session.add.assert_called_once_with(workflow_run)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
|
"""Test _handle_workflow_run_success method"""
|
||||||
|
# Mock _get_workflow_run to return the mock_workflow_run
|
||||||
|
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_workflow_run_success(
|
||||||
|
session=mock_session,
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
start_at=time.perf_counter() - 10, # 10 seconds ago
|
||||||
|
total_tokens=100,
|
||||||
|
total_steps=5,
|
||||||
|
outputs={"answer": "test answer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_workflow_run
|
||||||
|
assert result.status == WorkflowRunStatus.SUCCEEDED
|
||||||
|
assert result.outputs == json.dumps({"answer": "test answer"})
|
||||||
|
assert result.total_tokens == 100
|
||||||
|
assert result.total_steps == 5
|
||||||
|
assert result.finished_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
|
"""Test _handle_workflow_run_failed method"""
|
||||||
|
# Mock _get_workflow_run to return the mock_workflow_run
|
||||||
|
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
|
||||||
|
# Mock get_running_executions to return an empty list
|
||||||
|
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_workflow_run_failed(
|
||||||
|
session=mock_session,
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
start_at=time.perf_counter() - 10, # 10 seconds ago
|
||||||
|
total_tokens=50,
|
||||||
|
total_steps=3,
|
||||||
|
status=WorkflowRunStatus.FAILED,
|
||||||
|
error="Test error message",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_workflow_run
|
||||||
|
assert result.status == WorkflowRunStatus.FAILED.value
|
||||||
|
assert result.error == "Test error message"
|
||||||
|
assert result.total_tokens == 50
|
||||||
|
assert result.total_steps == 3
|
||||||
|
assert result.finished_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
|
||||||
|
"""Test _handle_node_execution_start method"""
|
||||||
|
# Create a mock event
|
||||||
|
event = MagicMock(spec=QueueNodeStartedEvent)
|
||||||
|
event.node_execution_id = "test-node-execution-id"
|
||||||
|
event.node_id = "test-node-id"
|
||||||
|
event.node_type = NodeType.LLM
|
||||||
|
|
||||||
|
# Create node_data as a separate mock
|
||||||
|
node_data = MagicMock()
|
||||||
|
node_data.title = "Test Node"
|
||||||
|
event.node_data = node_data
|
||||||
|
|
||||||
|
event.predecessor_node_id = "test-predecessor-node-id"
|
||||||
|
event.node_run_index = 1
|
||||||
|
event.parallel_mode_run_id = "test-parallel-mode-run-id"
|
||||||
|
event.in_iteration_id = "test-iteration-id"
|
||||||
|
event.in_loop_id = "test-loop-id"
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_node_execution_start(
|
||||||
|
workflow_run=mock_workflow_run,
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result.tenant_id == mock_workflow_run.tenant_id
|
||||||
|
assert result.app_id == mock_workflow_run.app_id
|
||||||
|
assert result.workflow_id == mock_workflow_run.workflow_id
|
||||||
|
assert result.workflow_run_id == mock_workflow_run.id
|
||||||
|
assert result.node_execution_id == event.node_execution_id
|
||||||
|
assert result.node_id == event.node_id
|
||||||
|
assert result.node_type == event.node_type.value
|
||||||
|
assert result.title == event.node_data.title
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
|
||||||
|
assert result.created_by_role == mock_workflow_run.created_by_role
|
||||||
|
assert result.created_by == mock_workflow_run.created_by
|
||||||
|
|
||||||
|
# Verify save was called
|
||||||
|
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
|
||||||
|
|
||||||
|
# Verify the node execution was added to the cache
|
||||||
|
assert workflow_cycle_manager._workflow_node_executions[event.node_execution_id] == result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
|
"""Test _get_workflow_run method"""
|
||||||
|
# Mock session.scalar to return the workflow run
|
||||||
|
mock_session.scalar.return_value = mock_workflow_run
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._get_workflow_run(
|
||||||
|
session=mock_session,
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_workflow_run
|
||||||
|
assert workflow_cycle_manager._workflow_run == mock_workflow_run
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_node_execution_success(workflow_cycle_manager):
|
||||||
|
"""Test _handle_workflow_node_execution_success method"""
|
||||||
|
# Create a mock event
|
||||||
|
event = MagicMock(spec=QueueNodeSucceededEvent)
|
||||||
|
event.node_execution_id = "test-node-execution-id"
|
||||||
|
event.inputs = {"input": "test input"}
|
||||||
|
event.process_data = {"process": "test process"}
|
||||||
|
event.outputs = {"output": "test output"}
|
||||||
|
event.execution_metadata = {"metadata": "test metadata"}
|
||||||
|
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
# Create a mock workflow node execution
|
||||||
|
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
node_execution.node_execution_id = "test-node-execution-id"
|
||||||
|
|
||||||
|
# Mock _get_workflow_node_execution to return the mock node execution
|
||||||
|
with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == node_execution
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
|
assert result.inputs == json.dumps(event.inputs)
|
||||||
|
assert result.process_data == json.dumps(event.process_data)
|
||||||
|
assert result.outputs == json.dumps(event.outputs)
|
||||||
|
assert result.finished_at is not None
|
||||||
|
assert result.elapsed_time is not None
|
||||||
|
|
||||||
|
# Verify update was called
|
||||||
|
workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
|
||||||
|
"""Test _handle_workflow_run_partial_success method"""
|
||||||
|
# Mock _get_workflow_run to return the mock_workflow_run
|
||||||
|
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||||
|
session=mock_session,
|
||||||
|
workflow_run_id="test-workflow-run-id",
|
||||||
|
start_at=time.perf_counter() - 10, # 10 seconds ago
|
||||||
|
total_tokens=75,
|
||||||
|
total_steps=4,
|
||||||
|
outputs={"partial_answer": "test partial answer"},
|
||||||
|
exceptions_count=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == mock_workflow_run
|
||||||
|
assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value
|
||||||
|
assert result.outputs == json.dumps({"partial_answer": "test partial answer"})
|
||||||
|
assert result.total_tokens == 75
|
||||||
|
assert result.total_steps == 4
|
||||||
|
assert result.exceptions_count == 2
|
||||||
|
assert result.finished_at is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
|
||||||
|
"""Test _handle_workflow_node_execution_failed method"""
|
||||||
|
# Create a mock event
|
||||||
|
event = MagicMock(spec=QueueNodeFailedEvent)
|
||||||
|
event.node_execution_id = "test-node-execution-id"
|
||||||
|
event.inputs = {"input": "test input"}
|
||||||
|
event.process_data = {"process": "test process"}
|
||||||
|
event.outputs = {"output": "test output"}
|
||||||
|
event.execution_metadata = {"metadata": "test metadata"}
|
||||||
|
event.start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
event.error = "Test error message"
|
||||||
|
|
||||||
|
# Create a mock workflow node execution
|
||||||
|
node_execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
node_execution.node_execution_id = "test-node-execution-id"
|
||||||
|
|
||||||
|
# Mock _get_workflow_node_execution to return the mock node execution
|
||||||
|
with patch.object(workflow_cycle_manager, "_get_workflow_node_execution", return_value=node_execution):
|
||||||
|
# Call the method
|
||||||
|
result = workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||||
|
event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert result == node_execution
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.FAILED.value
|
||||||
|
assert result.error == "Test error message"
|
||||||
|
assert result.inputs == json.dumps(event.inputs)
|
||||||
|
assert result.process_data == json.dumps(event.process_data)
|
||||||
|
assert result.outputs == json.dumps(event.outputs)
|
||||||
|
assert result.finished_at is not None
|
||||||
|
assert result.elapsed_time is not None
|
||||||
|
assert result.execution_metadata == json.dumps(event.execution_metadata)
|
||||||
|
|
||||||
|
# Verify update was called
|
||||||
|
workflow_cycle_manager._workflow_node_execution_repository.update.assert_called_once_with(node_execution)
|
@ -8,7 +8,7 @@ import pytest
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||||
from models.workflow import WorkflowNodeExecution
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
|||||||
"""Test get_by_node_execution_id method."""
|
"""Test get_by_node_execution_id method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Set up mock
|
# Set up mock
|
||||||
mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
|
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
@ -99,7 +99,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
|||||||
"""Test get_by_workflow_run method."""
|
"""Test get_by_workflow_run method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Set up mock
|
# Set up mock
|
||||||
mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
|
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
@ -120,7 +120,7 @@ def test_get_running_executions(repository, session, mocker: MockerFixture):
|
|||||||
"""Test get_running_executions method."""
|
"""Test get_running_executions method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Set up mock
|
# Set up mock
|
||||||
mock_select = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.select")
|
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
@ -158,7 +158,7 @@ def test_clear(repository, session, mocker: MockerFixture):
|
|||||||
"""Test clear method."""
|
"""Test clear method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Set up mock
|
# Set up mock
|
||||||
mock_delete = mocker.patch("core.repositories.workflow_node_execution.sqlalchemy_repository.delete")
|
mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
|
||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_delete.return_value = mock_stmt
|
mock_delete.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
Loading…
x
Reference in New Issue
Block a user