diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index bc2032c2a1..c9d3ad9ea3 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -33,7 +33,8 @@ logger = logging.getLogger(__name__)
 
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -120,6 +121,65 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             stream=stream
         )
+
+    def single_iteration_generate(self, app_model: App,
+                                  workflow: Workflow,
+                                  node_id: str,
+                                  user: Account,
+                                  args: dict,
+                                  stream: bool = True) \
+            -> dict[str, Any] | Generator[str, Any, None]:
+        """
+        Generate App response.
+
+        :param app_model: App
+        :param workflow: Workflow
+        :param node_id: the iteration node id to run
+        :param user: account
+        :param args: request args
+        :param stream: is stream
+        """
+        if not node_id:
+            raise ValueError('node_id is required')
+
+        if args.get('inputs') is None:
+            raise ValueError('inputs is required')
+
+        # convert to app config
+        app_config = AdvancedChatAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow
+        )
+
+        # init application generate entity
+        application_generate_entity = AdvancedChatAppGenerateEntity(
+            task_id=str(uuid.uuid4()),
+            app_config=app_config,
+            conversation_id=None,
+            inputs={},
+            query='',
+            files=[],
+            user_id=user.id,
+            stream=stream,
+            invoke_from=InvokeFrom.DEBUGGER,
+            extras={
+                "auto_generate_conversation_name": False
+            },
+            single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
+                node_id=node_id,
+                inputs=args['inputs']
+            )
+        )
+        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
+
+        return self._generate(
+            workflow=workflow,
+            user=user,
+            invoke_from=InvokeFrom.DEBUGGER,
+            application_generate_entity=application_generate_entity,
+            conversation=None,
+            stream=stream
+        )
 
     def _generate(self, *,
                   workflow: Workflow,
@@ -129,6 +189,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                   conversation: Optional[Conversation] = None,
                   stream: bool = True) \
            -> dict[str, Any] | Generator[str, Any, None]:
+        """
+        Generate App response.
+ + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ is_first_conversation = False if not conversation: is_first_conversation = True diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index f3eb23f810..be04ed29b0 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,14 +1,14 @@ import logging import os from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, @@ -17,52 +17,22 @@ from core.app.entities.app_invoke_entities import ( from core.app.entities.queue_entities import ( AppQueueEvent, QueueAnnotationReplyEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueParallelBranchRunFailedEvent, - QueueParallelBranchRunStartedEvent, - QueueRetrieverResourcesEvent, QueueStopEvent, QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, ) from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import ( - GraphEngineEvent, - GraphRunFailedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - IterationRunFailedEvent, - IterationRunNextEvent, - IterationRunStartedEvent, - IterationRunSucceededEvent, - NodeRunFailedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - ParallelBranchRunFailedEvent, - ParallelBranchRunStartedEvent, - ParallelBranchRunSucceededEvent, -) from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models.workflow import ConversationVariable logger = logging.getLogger(__name__) -class AdvancedChatAppRunner(AppRunner): +class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ AdvancedChat Application Runner """ @@ -80,8 +50,9 @@ class AdvancedChatAppRunner(AppRunner): :param conversation: conversation :param message: message """ + super().__init__(queue_manager) + self.application_generate_entity = application_generate_entity - self.queue_manager = queue_manager self.conversation = conversation self.message = message @@ -101,10 +72,6 @@ class AdvancedChatAppRunner(AppRunner): if not workflow: raise ValueError('Workflow not initialized') - inputs = self.application_generate_entity.inputs - query = self.application_generate_entity.query - files = 
self.application_generate_entity.files - user_id = None if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() @@ -113,6 +80,32 @@ class AdvancedChatAppRunner(AppRunner): else: user_id = self.application_generate_entity.user_id + workflow_callbacks: list[WorkflowCallback] = [] + if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + workflow_callbacks.append(WorkflowLoggingCallback()) + + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + node_id = self.application_generate_entity.single_iteration_run.node_id + user_inputs = self.application_generate_entity.single_iteration_run.inputs + + generator = WorkflowEntry.single_step_run_iteration( + workflow=workflow, + node_id=node_id, + user_id=self.application_generate_entity.user_id, + user_inputs=user_inputs, + callbacks=workflow_callbacks + ) + + for event in generator: + # TODO + self._handle_event(workflow_entry, event) + return + + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files + # moderation if self.handle_input_moderation( app_record=app_record, @@ -134,20 +127,16 @@ class AdvancedChatAppRunner(AppRunner): db.session.close() - workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - workflow_callbacks.append(WorkflowLoggingCallback()) - # Init conversation variables stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id + ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id ) with Session(db.engine) as session: conversation_variables = session.scalars(stmt).all() if not conversation_variables: conversation_variables = [ ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] @@ -160,9 +149,11 @@ class AdvancedChatAppRunner(AppRunner): system_inputs = { SystemVariable.QUERY: query, SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, + SystemVariable.CONVERSATION_ID: self.conversation.id, SystemVariable.USER_ID: user_id, } + + # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, @@ -174,9 +165,11 @@ class AdvancedChatAppRunner(AppRunner): workflow_entry = WorkflowEntry( workflow=workflow, user_id=self.application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, @@ -189,181 +182,6 @@ class AdvancedChatAppRunner(AppRunner): for event in generator: self._handle_event(workflow_entry, event) - def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: - """ - Handle event - :param workflow_entry: 
workflow entry - :param event: event - """ - if isinstance(event, GraphRunStartedEvent): - self._publish_event( - QueueWorkflowStartedEvent( - graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state - ) - ) - elif isinstance(event, GraphRunSucceededEvent): - self._publish_event( - QueueWorkflowSucceededEvent(outputs=event.outputs) - ) - elif isinstance(event, GraphRunFailedEvent): - self._publish_event( - QueueWorkflowFailedEvent(error=event.error) - ) - elif isinstance(event, NodeRunStartedEvent): - self._publish_event( - QueueNodeStartedEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - start_at=event.route_node_state.start_at, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - predecessor_node_id=event.predecessor_node_id - ) - ) - elif isinstance(event, NodeRunSucceededEvent): - self._publish_event( - QueueNodeSucceededEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - start_at=event.route_node_state.start_at, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result else {}, - ) - ) - elif isinstance(event, NodeRunFailedEvent): - self._publish_event( - QueueNodeFailedEvent( - node_execution_id=event.id, - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - start_at=event.route_node_state.start_at, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - error=event.route_node_state.node_run_result.error - if event.route_node_state.node_run_result - and event.route_node_state.node_run_result.error - else "Unknown error" - ) - ) - elif isinstance(event, NodeRunStreamChunkEvent): - self._publish_event( - QueueTextChunkEvent( - text=event.chunk_content - ) - ) - elif isinstance(event, NodeRunRetrieverResourceEvent): - self._publish_event( - QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources - ) - ) - elif isinstance(event, ParallelBranchRunStartedEvent): - self._publish_event( - QueueParallelBranchRunStartedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id - ) - ) - elif isinstance(event, ParallelBranchRunSucceededEvent): - self._publish_event( - QueueParallelBranchRunStartedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id - ) - ) - elif isinstance(event, ParallelBranchRunFailedEvent): - self._publish_event( - QueueParallelBranchRunFailedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - 
error=event.error - ) - ) - elif isinstance(event, IterationRunStartedEvent): - self._publish_event( - QueueIterationStartEvent( - node_execution_id=event.iteration_id, - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - node_data=event.iteration_node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - start_at=event.start_at, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id, - metadata=event.metadata - ) - ) - elif isinstance(event, IterationRunNextEvent): - self._publish_event( - QueueIterationNextEvent( - node_execution_id=event.iteration_id, - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - node_data=event.iteration_node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - index=event.index, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - output=event.pre_iteration_output, - ) - ) - elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): - self._publish_event( - QueueIterationCompletedEvent( - node_execution_id=event.iteration_id, - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - node_data=event.iteration_node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - start_at=event.start_at, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error if isinstance(event, IterationRunFailedEvent) else None - ) - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow - def handle_input_moderation( self, app_record: App, @@ -450,9 +268,3 @@ class AdvancedChatAppRunner(AppRunner): self._publish_event( QueueStopEvent(stopped_by=stopped_by) ) - - def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 63c1d1de7b..89d9a4deb9 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -240,11 +240,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc graph_runtime_state = None workflow_run = None - for message in self._queue_manager.listen(): - if tts_publisher: - tts_publisher.publish(message=message) - - event = message.event + for queue_message in self._queue_manager.listen(): + event = queue_message.event if isinstance(event, QueuePingEvent): yield self._ping_stream_response() @@ -433,6 +430,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if should_direct_answer: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text yield self._message_to_stream_response(delta_text, self._message.id) elif isinstance(event, 
QueueMessageReplaceEvent):
@@ -454,6 +455,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             else:
                 continue
 
+        # publish None when task finished
         if tts_publisher:
             tts_publisher.publish(None)
 
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index df40aec154..e293fc7b6f 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -4,7 +4,7 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Union
+from typing import Any, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -33,7 +33,8 @@ logger = logging.getLogger(__name__)
 
 class WorkflowAppGenerator(BaseAppGenerator):
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -101,13 +102,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
         )
 
     def _generate(
-        self, app_model: App,
+        self, *,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         application_generate_entity: WorkflowAppGenerateEntity,
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ) -> Union[dict, Generator[dict, None, None]]:
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
 
@@ -128,7 +130,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
 
         # new thread
         worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
+            'flask_app': current_app._get_current_object(),  # type: ignore
            'application_generate_entity': application_generate_entity,
             'queue_manager': queue_manager,
             'context': contextvars.copy_context()
@@ -155,7 +157,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
                                  node_id: str,
                                  user: Account,
                                  args: dict,
-                                 stream: bool = True):
+                                 stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.
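# Illustrative usage sketch (not part of the patch): how the single-iteration
# debug entry point above might be invoked from a controller. `app`,
# `draft_workflow`, and `account` are assumed stand-ins loaded elsewhere; only
# `single_iteration_generate` and its signature come from this diff. Debug runs
# are pinned to InvokeFrom.DEBUGGER internally, so no invoke_from argument is
# exposed here.
#
#     response = WorkflowAppGenerator().single_iteration_generate(
#         app_model=app,              # App record being debugged
#         workflow=draft_workflow,    # draft workflow containing the iteration node
#         node_id='1719290000000',    # hypothetical iteration node id
#         user=account,               # Account only; end users cannot debug
#         args={'inputs': {'iterator_selector': ['a', 'b', 'c']}},
#         stream=True,                # True yields a Generator of stream chunks
#     )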
@@ -172,10 +174,6 @@ class WorkflowAppGenerator(BaseAppGenerator): if args.get('inputs') is None: raise ValueError('inputs is required') - extras = { - "auto_generate_conversation_name": False - } - # convert to app config app_config = WorkflowAppConfigManager.get_app_config( app_model=app_model, @@ -191,7 +189,9 @@ class WorkflowAppGenerator(BaseAppGenerator): user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={ + "auto_generate_conversation_name": False + }, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args['inputs'] @@ -224,22 +224,12 @@ class WorkflowAppGenerator(BaseAppGenerator): with flask_app.app_context(): try: # workflow app - runner = WorkflowAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager + ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: @@ -251,14 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 9a100532b0..0175599938 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,9 +1,10 @@ import logging import os -from typing import Optional, cast +from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, @@ -15,33 +16,44 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, EndUser -from models.workflow import Workflow logger = logging.getLogger(__name__) -class WorkflowAppRunner: +class WorkflowAppRunner(WorkflowBasedAppRunner): """ Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: + def __init__( + self, + 
application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + + def run(self) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: @@ -51,82 +63,63 @@ class WorkflowAppRunner: if not workflow: raise ValueError('Workflow not initialized') - inputs = application_generate_entity.inputs - files = application_generate_entity.files - db.session.close() workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + node_id = self.application_generate_entity.single_iteration_run.node_id + user_inputs = self.application_generate_entity.single_iteration_run.inputs + + generator = WorkflowEntry.single_step_run_iteration( + workflow=workflow, + node_id=node_id, + user_id=self.application_generate_entity.user_id, + user_inputs=user_inputs, + callbacks=workflow_callbacks + ) + + for event in generator: + # TODO + self._handle_event(workflow_entry, event) + return + + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + # Create a variable pool. 
system_inputs = { SystemVariable.FILES: files, SystemVariable.USER_ID: user_id, } + variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, conversation_variables=[], ) - + # RUN WORKFLOW - workflow_entry = WorkflowEntry() - workflow_entry.run( + workflow_entry = WorkflowEntry( workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') - - if not app_record.workflow_id: - raise ValueError('Workflow not initialized') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") - - workflow_callbacks = [] - - workflow_entry = WorkflowEntry() - workflow_entry.single_step_run_iteration_workflow_node( - workflow=workflow, - node_id=node_id, - user_id=user_id, - user_inputs=inputs, + generator = workflow_entry.run( callbacks=workflow_callbacks ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 1b3379d39a..6955844af5 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Generator @@ -15,7 +16,6 @@ from core.app.entities.queue_entities import ( QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, - QueueMessageReplaceEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, @@ -32,10 +32,10 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, - TextReplaceStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, + WorkflowStartStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -120,24 +120,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, 
WorkflowFinishStreamResponse): - workflow_run = self._task_state.workflow_run - if not workflow_run: - raise Exception('Workflow run not found.') - response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status, - outputs=workflow_run.outputs_dict, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at) ) ) @@ -153,12 +149,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa To stream response. :return: """ + workflow_run_id = None for stream_response in generator: - if not self._task_state.workflow_run: - raise Exception('Workflow run not found.') + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run.id, + workflow_run_id=workflow_run_id, stream_response=stream_response ) @@ -173,17 +170,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -193,9 +191,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -213,55 +211,105 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa def _process_stream_response( self, - 
publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event + graph_runtime_state = None + workflow_run = None - if isinstance(event, QueueErrorEvent): + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_execution_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') + + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) yield self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) yield self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) + yield self._workflow_node_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + 
) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, trace_manager=trace_manager + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None, + conversation_id=None, + trace_manager=trace_manager, ) # save workflow app log @@ -276,17 +324,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if delta_text is None: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text yield self._text_chunk_to_stream_response(delta_text) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._text_replace_to_stream_response(event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() else: continue - if publisher: - publisher.publish(None) + if tts_publisher: + tts_publisher.publish(None) def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: @@ -305,15 +353,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # not save log for debugging return - workflow_app_log = WorkflowAppLog( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, - created_from=created_from.value, - created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), - created_by=self._user.id, - ) + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user' + workflow_app_log.created_by = self._user.id + db.session.add(workflow_app_log) db.session.commit() db.session.close() @@ -330,14 +378,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa ) return response - - def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: - """ - Text replace to stream response. 
- :param text: text - :return: - """ - return TextReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - text=TextReplaceStreamResponse.Data(text=text) - ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py new file mode 100644 index 0000000000..212fde82f8 --- /dev/null +++ b/api/core/app/apps/workflow_app_runner.py @@ -0,0 +1,228 @@ +from typing import Optional + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowBasedAppRunner(AppRunner): + def __init__(self, queue_manager: AppQueueManager): + self.queue_manager = queue_manager + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent( + graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state + ) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event( + QueueWorkflowSucceededEvent(outputs=event.outputs) + ) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event( + QueueWorkflowFailedEvent(error=event.error) + ) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + predecessor_node_id=event.predecessor_node_id + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if 
event.route_node_state.node_run_result else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result else {}, + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result + and event.route_node_state.node_run_result.error + else "Unknown error" + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + error=event.error + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + 
outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish( + event, + PublishFrom.APPLICATION_MANAGER + ) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index bc3ec7980c..8ecbd1ecea 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -438,7 +438,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ - workflow_run_id: str + workflow_run_id: Optional[str] = None class AppBlockingResponse(BaseModel): diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 39ca88869a..2f74a180d1 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Message) -> Exception: + def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception: """ Handle error event. :param event: event diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index fe79fadf66..80b2501da2 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,10 +1,9 @@ from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): - self.node_id = node_id - self.node_type = node_type - self.node_title = node_title + def __init__(self, node_instance: BaseNode, error: str): + self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_title} run failed: {error}") + super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index b54a0f9cdb..d2311b04e9 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,7 +1,8 @@ -from typing import cast +from typing import Any, Mapping, Sequence, cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, @@ -52,9 +53,16 @@ class AnswerNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData + ) -> 
Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -66,6 +74,6 @@ class AnswerNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 7c5d2858e8..3807bbb2d5 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -67,19 +67,35 @@ class BaseNode(ABC): yield from result @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict): + def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config :param config: node config :return: """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping(node_data) + return cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + node_id=node_id, + node_data=node_data + ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: BaseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 6395e91e53..7c066ad083 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Any, Mapping, Optional, Sequence, Union, cast from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage @@ -314,13 +314,19 @@ class CodeNode(BaseNode): return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + '.' 
+ variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index fc24873d16..8299f4d9f2 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, Mapping, Sequence, cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -32,9 +32,16 @@ class EndNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 05690fcc01..6a94a6bd32 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,7 +1,7 @@ import logging from mimetypes import guess_extension from os import path -from typing import cast +from typing import Any, Mapping, Sequence, cast from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar @@ -107,13 +107,19 @@ class HttpRequestNode(BaseNode): return timeout @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = cast(HttpRequestNodeData, node_data) try: http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) @@ -121,7 +127,7 @@ class HttpRequestNode(BaseNode): variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + '.' 
+ variable_selector.variable] = variable_selector.value_selector return variable_mapping except Exception as e: diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index e9c325416d..feb0175a74 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, Mapping, Sequence, cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -99,9 +99,16 @@ class IfElseNode(BaseNode): return data @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c5190440bd..f7904aa836 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,10 +1,11 @@ import logging -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from datetime import datetime, timezone from typing import Any, cast from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, @@ -287,12 +288,67 @@ class IterationNode(BaseNode): variable_pool.remove([self.node_id, 'item']) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - 'input_selector': node_data.iterator_selector, + variable_mapping = { + f'{node_id}.input_selector': node_data.iterator_selector, } + + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=node_data.start_node_id + ) + + if not iteration_graph: + raise ValueError('iteration graph not found') + + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): + if sub_node_config.get('data', {}).get('iteration_id') != node_id: + continue + + # variable selector to variable mapping + try: + # Get node class + from core.workflow.nodes.node_mapping import node_classes + node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + if not node_cls: + continue + + node_cls = cast(BaseNode, node_cls) + + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + config=sub_node_config + ) + sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + except 
NotImplementedError:
+                sub_node_variable_mapping = {}
+
+            # prefix with the sub node id and drop selectors that point back at
+            # the iteration node itself (its item/index variables)
+            sub_node_variable_mapping = {
+                sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
+                if value[0] != node_id
+            }
+
+            variable_mapping.update(sub_node_variable_mapping)
+
+        # keep only variables sourced outside the iteration; values produced by
+        # nodes inside it are resolved at runtime, not from user inputs
+        variable_mapping = {
+            key: value for key, value in variable_mapping.items()
+            if value[0] not in iteration_graph.node_ids
+        }
+
+        return variable_mapping
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 7c6811d0e8..1e9ff9ff79 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, cast
+from typing import Any, Mapping, Sequence, cast
 
 from sqlalchemy import func
 
@@ -232,11 +232,21 @@ class KnowledgeRetrievalNode(BaseNode):
         return context_list
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
-        node_data = node_data
-        node_data = cast(cls._node_data_cls, node_data)
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: KnowledgeRetrievalNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
         variable_mapping = {}
-        variable_mapping['query'] = node_data.query_variable_selector
+        variable_mapping[node_id + '.query'] = node_data.query_variable_selector
         return variable_mapping
 
     def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index 4e5ecb42b4..5fdf2456df 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -1,7 +1,7 @@
 import json
 from collections.abc import Generator
 from copy import deepcopy
-from typing import Optional, cast
+from typing import Any, Mapping, Optional, Sequence, cast
 
 from pydantic import BaseModel
 
@@ -678,13 +678,19 @@ class LLMNode(BaseNode):
         db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: LLMNodeData
+    ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
         :param node_data: node data
         :return:
         """
-        node_data = cast(LLMNodeData, node_data)
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -734,6 +740,10 @@
         for variable_selector in node_data.prompt_config.jinja2_variables or []:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
+        variable_mapping = {
+            node_id + '.' 
+ key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index eb28052f72..f4ff251ead 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,6 +1,6 @@ import json import uuid -from typing import Optional, cast +from typing import Any, Mapping, Optional, Sequence, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -701,15 +701,19 @@ class ParameterExtractorNode(LLMNode): return self._model_instance, self._model_config @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - variable_mapping = { 'query': node_data.query } @@ -719,4 +723,8 @@ class ParameterExtractorNode(LLMNode): for selector in variable_template_parser.extract_variable_selectors(): variable_mapping[selector.variable] = selector.value_selector + variable_mapping = { + node_id + '.' + key: value for key, value in variable_mapping.items() + } + return variable_mapping diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index dc757b7608..97996872d9 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union, cast +from typing import Any, Mapping, Optional, Sequence, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -137,9 +137,19 @@ class QuestionClassifierNode(LLMNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {'query': node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: @@ -147,6 +157,11 @@ class QuestionClassifierNode(LLMNode): variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = { + node_id + '.' 
+ key: value for key, value in variable_mapping.items() + } + return variable_mapping @classmethod diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 61880b82ff..826c3526e6 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,4 +1,5 @@ +from typing import Any, Mapping, Sequence from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode @@ -28,9 +29,16 @@ class StartNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 3406763b97..4a19792c64 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,5 +1,5 @@ import os -from typing import Optional, cast +from typing import Any, Mapping, Optional, Sequence, cast from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -77,13 +77,19 @@ class TemplateTransformNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c10ee542f1..ebbf25c823 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -221,9 +221,16 @@ class ToolNode(BaseNode): return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ @@ -239,4 +246,8 @@ class ToolNode(BaseNode): elif input.type == 'constant': pass + result = { + node_id + '.' 
+ key: value for key, value in result.items() + } + return result diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 61ee59ec92..186bbce2af 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, Mapping, Sequence, cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -48,5 +48,17 @@ class VariableAggregatorNode(BaseNode): ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: VariableAssignerNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ return {} diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 3c714069d3..8ab5d27eb2 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,6 +1,8 @@ import logging +import time +import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional, Type, cast from configs import dify_config from core.app.app_config.entities import FileExtraConfig @@ -8,13 +10,18 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent +from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.node_mapping import node_classes from models.workflow import ( @@ -32,18 +39,17 @@ class WorkflowEntry: user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, - user_inputs: Mapping[str, Any], - system_inputs: Mapping[SystemVariable, Any], - call_depth: int = 0 + call_depth: int, + variable_pool: VariablePool ) -> None: """ :param workflow: Workflow instance :param user_id: user id :param user_from: user from 
:param invoke_from: invoke from service-api, web-app, debugger, explore
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
         :param call_depth: call depth
+        :param variable_pool: variable pool
+            (carries the system variables, user inputs and environment variables)
         """
         # fetch workflow graph
         graph_config = workflow.graph_dict
@@ -71,13 +77,6 @@ class WorkflowEntry:
         if not graph:
             raise ValueError('graph not found in workflow')
 
-        # init variable pool
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=user_inputs,
-            environment_variables=workflow.environment_variables,
-        )
-
         # init workflow run state
         self.graph_engine = GraphEngine(
             tenant_id=workflow.tenant_id,
@@ -134,10 +133,161 @@
         )
 
         return
 
-    def single_step_run(self, workflow: Workflow,
-                        node_id: str,
-                        user_id: str,
-                        user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
+    @classmethod
+    def single_step_run_iteration(
+        cls,
+        workflow: Workflow,
+        node_id: str,
+        user_id: str,
+        user_inputs: dict,
+        callbacks: Sequence[WorkflowCallback],
+    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Single step run workflow node iteration
+        :param workflow: Workflow instance
+        :param node_id: node id
+        :param user_id: user id
+        :param user_inputs: user inputs
+        :param callbacks: workflow callbacks
+        :return:
+        """
+        # fetch workflow graph
+        graph_config = workflow.graph_dict
+        if not graph_config:
+            raise ValueError('workflow graph not found')
+
+        graph_config = cast(dict[str, Any], graph_config)
+
+        if 'nodes' not in graph_config or 'edges' not in graph_config:
+            raise ValueError('nodes or edges not found in workflow graph')
+
+        if not isinstance(graph_config.get('nodes'), list):
+            raise ValueError('nodes in workflow graph must be a list')
+
+        if not isinstance(graph_config.get('edges'), list):
+            raise ValueError('edges in workflow graph must be a list')
+
+        # filter nodes only in iteration
+        node_configs = [
+            node for node in graph_config.get('nodes', [])
+            if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
+        ]
+
+        graph_config['nodes'] = node_configs
+
+        node_ids = [node.get('id') for node in node_configs]
+
+        # filter edges only in iteration
+        edge_configs = [
+            edge for edge in graph_config.get('edges', [])
+            if (edge.get('source') is None or edge.get('source') in node_ids)
+            and (edge.get('target') is None or edge.get('target') in node_ids)
+        ]
+
+        graph_config['edges'] = edge_configs
+
+        # init graph
+        graph = Graph.init(
+            graph_config=graph_config,
+            root_node_id=node_id
+        )
+
+        if not graph:
+            raise ValueError('graph not found in workflow')
+
+        # fetch node config from node id
+        iteration_node_config = None
+        for node in node_configs:
+            if node.get('id') == node_id:
+                iteration_node_config = node
+                break
+
+        if not iteration_node_config:
+            raise ValueError('iteration node id not found in workflow graph')
+
+        # Get node class
+        node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        node_cls = cast(type[BaseNode], node_cls)
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=workflow.environment_variables,
+        )
+
+        try:
+            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                graph_config=workflow.graph_dict,
+                config=iteration_node_config
+            )
+        except NotImplementedError:
+            variable_mapping = {}
+
+        cls._mapping_user_inputs_to_variable_pool(
+            variable_mapping=variable_mapping,
+            
user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get('data', {})) + ) + + # init workflow run state + graph_engine = GraphEngine( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=1, + graph=graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + ) + + try: + # run workflow + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event( + graph=graph_engine.graph, + graph_init_params=graph_engine.init_params, + graph_runtime_state=graph_engine.graph_runtime_state, + event=event + ) + yield event + except GenerateTaskStoppedException: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + if callbacks: + for callback in callbacks: + callback.on_event( + graph=graph_engine.graph, + graph_init_params=graph_engine.init_params, + graph_runtime_state=graph_engine.graph_runtime_state, + event=GraphRunFailedEvent( + error=str(e) + ) + ) + return + + @classmethod + def single_step_run( + cls, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict + ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: """ Single step run workflow node :param workflow: Workflow instance @@ -168,61 +317,74 @@ class WorkflowEntry: # Get node class node_type = NodeType.value_of(node_config.get('data', {}).get('type')) node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) if not node_cls: raise ValueError(f'Node class not found for node type {node_type}') + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + # init graph + graph = Graph.init( + graph_config=workflow.graph_dict + ) # init workflow run state - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + node_instance: BaseNode = node_cls( + id=str(uuid.uuid4()), config=node_config, - workflow_call_depth=0 + graph_init_params=GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0 + ), + graph=graph, + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter() + ) ) try: - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) - # variable selector to variable mapping try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=node_config + ) except NotImplementedError: variable_mapping = {} - self._mapping_user_inputs_to_variable_pool( + cls._mapping_user_inputs_to_variable_pool( 
variable_mapping=variable_mapping,
                 user_inputs=user_inputs,
                 variable_pool=variable_pool,
                 tenant_id=workflow.tenant_id,
-                node_instance=node_instance
+                node_type=node_type,
+                node_data=node_instance.node_data
             )
 
             # run node
-            node_run_result = node_instance.run(
-                variable_pool=variable_pool
-            )
+            generator = node_instance.run()
 
-            # sign output files
-            node_run_result.outputs = self.handle_special_values(node_run_result.outputs)
+            return node_instance, generator
         except Exception as e:
             raise WorkflowNodeRunFailedError(
-                node_id=node_instance.node_id,
-                node_type=node_instance.node_type,
-                node_title=node_instance.node_data.title,
+                node_instance=node_instance,
                 error=str(e)
             )
-        return node_instance, node_run_result
-
     @classmethod
     def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
         """
@@ -250,33 +413,50 @@
 
         return new_value
 
-    def _mapping_user_inputs_to_variable_pool(self,
-                                              variable_mapping: dict,
-                                              user_inputs: dict,
-                                              variable_pool: VariablePool,
-                                              tenant_id: str,
-                                              node_instance: BaseNode):
-        for variable_key, variable_selector in variable_mapping.items():
-            if variable_key not in user_inputs and not variable_pool.get(variable_selector):
-                raise ValueError(f'Variable key {variable_key} not found in user inputs.')
+    @classmethod
+    def _mapping_user_inputs_to_variable_pool(
+        cls,
+        variable_mapping: Mapping[str, Sequence[str]],
+        user_inputs: dict,
+        variable_pool: VariablePool,
+        tenant_id: str,
+        node_type: NodeType,
+        node_data: BaseNodeData
+    ) -> None:
+        for node_variable, variable_selector in variable_mapping.items():
+            # fetch node id and variable key from node_variable
+            node_variable_list = node_variable.split('.')
+            if len(node_variable_list) < 1:
+                raise ValueError(f'Invalid node variable {node_variable}')
+
+            # everything after the node id is the bare variable key
+            node_variable_key = '.'.join(node_variable_list[1:])
+
+            if (
+                node_variable_key not in user_inputs
+                and node_variable not in user_inputs
+            ) and not variable_pool.get(variable_selector):
+                raise ValueError(f'Variable key {node_variable} not found in user inputs.')
 
             # fetch variable node id from variable selector
             variable_node_id = variable_selector[0]
             variable_key_list = variable_selector[1:]
+            variable_key_list = cast(list[str], variable_key_list)
 
-            # get value
-            value = user_inputs.get(variable_key)
+            # get input value, preferring the node-scoped key over the bare key
+            input_value = user_inputs.get(node_variable)
+            if input_value is None:
+                input_value = user_inputs.get(node_variable_key)
 
             # FIXME: temp fix for image type
-            if node_instance.node_type == NodeType.LLM:
+            if node_type == NodeType.LLM:
                 new_value = []
-                if isinstance(value, list):
-                    node_data = node_instance.node_data
+                if isinstance(input_value, list):
                     node_data = cast(LLMNodeData, node_data)
 
                     detail = node_data.vision.configs.detail if node_data.vision.configs else None
 
-                    for item in value:
+                    for item in input_value:
                         if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
                             transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
                             file = FileVar(
@@ -294,4 +474,5 @@
-                    value = new_value
+                    # keep the converted file list as the value to append
+                    input_value = new_value
 
         # append variable and value to variable pool
-        variable_pool.add([variable_node_id] + variable_key_list, value)
+        variable_pool.add([variable_node_id] + variable_key_list, input_value)
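The reworked `_mapping_user_inputs_to_variable_pool` above accepts mapping keys in the new `node_id.variable` form while still honouring bare variable names. A minimal standalone sketch of that resolution logic (a hypothetical helper for illustration only, not part of this patch):

```python
from collections.abc import Mapping, Sequence


def resolve_user_inputs(
    variable_mapping: Mapping[str, Sequence[str]],
    user_inputs: dict,
) -> dict[tuple[str, ...], object]:
    """Resolve node-scoped mapping keys against user-supplied inputs."""
    resolved = {}
    for node_variable, selector in variable_mapping.items():
        # 'llm_1.query' -> bare key 'query'; everything after the node id
        bare_key = '.'.join(node_variable.split('.')[1:])

        if node_variable in user_inputs:
            value = user_inputs[node_variable]
        elif bare_key in user_inputs:
            value = user_inputs[bare_key]
        else:
            raise ValueError(f'Variable key {node_variable} not found in user inputs.')

        # selector[0] names the source node, the rest is the variable path
        resolved[tuple(selector)] = value
    return resolved


print(resolve_user_inputs({'llm_1.query': ['start_1', 'query']}, {'query': 'hello'}))
# -> {('start_1', 'query'): 'hello'}
```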
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 15f66b0f81..436274daf7 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -10,6 +10,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.features.rate_limiting import RateLimit
 from models.model import Account, App, AppMode, EndUser
+from models.workflow import Workflow
 from services.workflow_service import WorkflowService
 
@@ -95,11 +96,10 @@ class AppGenerateService:
 
     @classmethod
     def generate_single_iteration(cls, app_model: App,
-                                  user: Union[Account, EndUser],
+                                  user: Account,
                                   node_id: str,
                                   args: Any,
                                   streaming: bool = True):
-        # TODO
         if app_model.mode == AppMode.ADVANCED_CHAT.value:
             workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
             return AdvancedChatAppGenerator().single_iteration_generate(
@@ -145,7 +145,7 @@ class AppGenerateService:
         )
 
     @classmethod
-    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any:
+    def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
         """
         Get workflow
         :param app_model: app model
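Most of the reworked node mappings above now key their entries as `node_id.variable` rather than a bare variable name. The point of the prefix is that mappings from different nodes can be merged without key collisions, which the iteration node relies on when it aggregates the mappings of its sub-nodes. A toy illustration of the convention (all names invented for the example):

```python
def scope_mapping(node_id: str, mapping: dict[str, list[str]]) -> dict[str, list[str]]:
    # prefix each key with its owning node id, mirroring the convention in this patch
    return {f'{node_id}.{key}': selector for key, selector in mapping.items()}


# two nodes both expose a 'query' variable; unprefixed keys would collide on merge
merged: dict[str, list[str]] = {}
merged.update(scope_mapping('llm_1', {'query': ['start_1', 'query']}))
merged.update(scope_mapping('http_1', {'query': ['start_1', 'query']}))
assert set(merged) == {'llm_1.query', 'http_1.query'}
```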
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 98b84cb98a..5ee8ce4700 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -8,8 +8,9 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.segments import Variable
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.node_mapping import node_classes
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -213,81 +214,62 @@ class WorkflowService:
             raise ValueError('Workflow not initialized')
 
         # run draft workflow node
-        workflow_entry = WorkflowEntry()
         start_at = time.perf_counter()
 
         try:
-            node_instance, node_run_result = workflow_entry.single_step_run(
+            node_instance, generator = WorkflowEntry.single_step_run(
                 workflow=draft_workflow,
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
            )
+
+            node_run_result: NodeRunResult | None = None
+            for event in generator:
+                if isinstance(event, RunCompletedEvent):
+                    node_run_result = event.run_result
+
+                    # sign output files
+                    node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+                    break
+
+            if not node_run_result:
+                raise ValueError('Node run failed with no run result')
+
+            run_succeeded = node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+            error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=e.node_id,
-                node_type=e.node_type.value,
-                title=e.node_title,
-                status=WorkflowNodeExecutionStatus.FAILED.value,
-                error=e.error,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
-            )
-            db.session.add(workflow_node_execution)
-            db.session.commit()
+            node_instance = e.node_instance
+            run_succeeded = False
+            node_run_result = None
+            error = e.error
 
-            return workflow_node_execution
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = app_model.tenant_id
+        workflow_node_execution.app_id = app_model.id
+        workflow_node_execution.workflow_id = draft_workflow.id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
+        workflow_node_execution.index = 1
+        workflow_node_execution.node_id = node_id
+        workflow_node_execution.node_type = node_instance.node_type.value
+        workflow_node_execution.title = node_instance.node_data.title
+        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
+        workflow_node_execution.created_by = account.id
+        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
-        if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+        if run_succeeded and node_run_result:
             # create workflow node execution
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=node_id,
-                node_type=node_instance.node_type.value,
-                title=node_instance.node_data.title,
-                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
-                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
-                outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
-                execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
-                                    if node_run_result.metadata else None),
-                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
-            )
+            workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
+            workflow_node_execution.process_data = json.dumps(node_run_result.process_data) if node_run_result.process_data else None
+            workflow_node_execution.outputs = json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
+            workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
         else:
             # create workflow node execution
-            workflow_node_execution = WorkflowNodeExecution(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                workflow_id=draft_workflow.id,
-                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-                index=1,
-                node_id=node_id,
-                node_type=node_instance.node_type.value,
-                title=node_instance.node_data.title,
-                status=node_run_result.status.value,
-                error=node_run_result.error,
-                elapsed_time=time.perf_counter() - start_at,
-                created_by_role=CreatedByRole.ACCOUNT.value,
-                created_by=account.id,
-                created_at=datetime.now(timezone.utc).replace(tzinfo=None),
-                finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
-            )
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
 
         db.session.add(workflow_node_execution)
         db.session.commit()
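With `single_step_run` now returning an event generator instead of a finished `NodeRunResult`, callers drain the generator and pick out the `RunCompletedEvent`, as `run_draft_workflow_node` does above. A condensed sketch of that calling convention (error handling and persistence trimmed; not a verbatim excerpt of the patch):

```python
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.workflow_entry import WorkflowEntry


def run_node_once(draft_workflow, node_id: str, user_inputs: dict, account_id: str):
    """Drive a single-step node run to completion and return its result."""
    node_instance, generator = WorkflowEntry.single_step_run(
        workflow=draft_workflow,
        node_id=node_id,
        user_inputs=user_inputs,
        user_id=account_id,
    )

    # the node streams events; the final result arrives as a RunCompletedEvent
    for event in generator:
        if isinstance(event, RunCompletedEvent):
            run_result = event.run_result
            # sign output files before handing the result back
            run_result.outputs = WorkflowEntry.handle_special_values(run_result.outputs)
            return node_instance, run_result

    raise ValueError('Node run finished without a RunCompletedEvent')
```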
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index d27f96d8ff..06947d6439 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -22,6 +22,7 @@ def test_execute_code(setup_code_executor_mock):
     # trim first 4 spaces at the beginning of each line
     code = '\n'.join([line[4:] for line in code.split('\n')])
     node = CodeNode(
+        id='test',
         tenant_id='1',
         app_id='1',
         workflow_id='1',