diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index ae86463407..4e91508786 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -3,6 +3,20 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.end.end_node import EndNode
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -41,6 +55,25 @@ class NodeType(Enum):
         raise ValueError(f'invalid node type value {value}')
 
 
+node_classes = {
+    NodeType.START: StartNode,
+    NodeType.END: EndNode,
+    NodeType.ANSWER: AnswerNode,
+    NodeType.LLM: LLMNode,
+    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
+    NodeType.IF_ELSE: IfElseNode,
+    NodeType.CODE: CodeNode,
+    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
+    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
+    NodeType.HTTP_REQUEST: HttpRequestNode,
+    NodeType.TOOL: ToolNode,
+    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
+    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
+    NodeType.ITERATION: IterationNode,
+    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
+}
+
+
 class SystemVariable(Enum):
     """
     System Variables.
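This mapping is what `GraphEngine._run_node` (further down in this patch) uses to turn a serialized node config into a concrete node class. A minimal sketch of that lookup; the `node_config` literal here is a hypothetical example, not taken from this patch:

```python
from core.workflow.entities.node_entities import NodeType, node_classes

# Hypothetical node config as it would appear in a serialized workflow graph.
node_config = {"id": "node-1", "data": {"type": "llm", "title": "LLM"}}

# value_of() maps the raw string to the enum; node_classes maps the enum
# to the BaseNode subclass the engine will instantiate.
node_type = NodeType.value_of(node_config["data"]["type"])  # NodeType.LLM
node_cls = node_classes.get(node_type)                      # LLMNode
```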
@@ -90,3 +123,23 @@ class NodeRunResult(BaseModel):
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
 
     error: Optional[str] = None  # error message if status is failed
+
+
+class UserFrom(Enum):
+    """
+    User from
+    """
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+    @classmethod
+    def value_of(cls, value: str) -> "UserFrom":
+        """
+        Value of
+        :param value: value
+        :return:
+        """
+        for item in cls:
+            if item.value == value:
+                return item
+        raise ValueError(f"Invalid value: {value}")
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
new file mode 100644
index 0000000000..a9a984f83d
--- /dev/null
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -0,0 +1,116 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+from core.workflow.entities.node_entities import NodeRunResult
+
+
+class GraphEngineEvent(BaseModel):
+    pass
+
+
+###########################################
+# Graph Events
+###########################################
+
+
+class BaseGraphEvent(GraphEngineEvent):
+    pass
+
+
+class GraphRunStartedEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunBackToRootEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunSucceededEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunFailedEvent(BaseGraphEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+###########################################
+# Node Events
+###########################################
+
+
+class BaseNodeEvent(GraphEngineEvent):
+    node_id: str = Field(..., description="node id")
+    parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
+    # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
+
+
+class NodeRunStartedEvent(BaseNodeEvent):
+    pass
+
+
+class NodeRunStreamChunkEvent(BaseNodeEvent):
+    chunk_content: str = Field(..., description="chunk content")
+
+
+class NodeRunRetrieverResourceEvent(BaseNodeEvent):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+
+
+class NodeRunSucceededEvent(BaseNodeEvent):
+    # optional: the engine may emit a success event before the run result is recorded
+    run_result: Optional[NodeRunResult] = Field(None, description="run result")
+
+
+class NodeRunFailedEvent(BaseNodeEvent):
+    # optional: failures raised before a node produces a result carry only a reason
+    run_result: Optional[NodeRunResult] = Field(None, description="run result")
+    reason: str = Field("", description="failed reason")
+
+    @model_validator(mode='before')
+    def init_reason(cls, values: dict) -> dict:
+        if not values.get("reason"):
+            values["reason"] = getattr(values.get("run_result"), "error", None) or "Unknown error"
+        return values
+
+
+###########################################
+# Parallel Events
+###########################################
+
+
+class BaseParallelEvent(GraphEngineEvent):
+    parallel_id: str = Field(..., description="parallel id")
+
+
+class ParallelRunStartedEvent(BaseParallelEvent):
+    pass
+
+
+class ParallelRunSucceededEvent(BaseParallelEvent):
+    pass
+
+
+class ParallelRunFailedEvent(BaseParallelEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+###########################################
+# Iteration Events
+###########################################
+
+
+class BaseIterationEvent(GraphEngineEvent):
+    iteration_id: str = Field(..., description="iteration id")
+
+
+class IterationRunStartedEvent(BaseIterationEvent):
+    pass
+
+
+class IterationRunSucceededEvent(BaseIterationEvent):
+    pass
+
+
+class IterationRunFailedEvent(BaseIterationEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index deddd255c0..dab4e30da6 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
 
 
 class GraphEdge(BaseModel):
-    source_node_id: str
-    """source node id"""
-
-    target_node_id: str
-    """target node id"""
-
-    run_condition: Optional[RunCondition] = None
-    """condition to run the edge"""
+    source_node_id: str = Field(..., description="source node id")
+    target_node_id: str = Field(..., description="target node id")
+    run_condition: Optional[RunCondition] = Field(None, description="run condition")
 
 
 class GraphParallel(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    """random uuid parallel id"""
-
-    start_from_node_id: str
-    """start from node id"""
-
-    end_to_node_id: Optional[str] = None
-    """end to node id"""
-
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if exists"""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
+    start_from_node_id: str = Field(..., description="start from node id")
+    parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
+    end_to_node_id: Optional[str] = Field(None, description="end to node id")
 
 
 class Graph(BaseModel):
-    root_node_id: str
-    """root node id of the graph"""
-
-    node_ids: list[str] = Field(default_factory=list)
-    """graph node ids"""
-
-    node_id_config_mapping: dict[str, dict] = Field(default_factory=list)
-    """node configs mapping (node id: node config)"""
-
-    edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
-    """graph edge mapping (source node id: edges)"""
-
-    parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
-    """graph parallel mapping (parallel id: parallel)"""
-
-    node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
-    """graph node parallel mapping (node id: parallel id)"""
+    root_node_id: str = Field(..., description="root node id of the graph")
+    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
+    node_id_config_mapping: dict[str, dict] = Field(
+        default_factory=dict,
+        description="node configs mapping (node id: node config)"
+    )
+    edge_mapping: dict[str, list[GraphEdge]] = Field(
+        default_factory=dict,
+        description="graph edge mapping (source node id: edges)"
+    )
+    parallel_mapping: dict[str, GraphParallel] = Field(
+        default_factory=dict,
+        description="graph parallel mapping (parallel id: parallel)"
+    )
+    node_parallel_mapping: dict[str, str] = Field(
+        default_factory=dict,
+        description="graph node parallel mapping (node id: parallel id)"
+    )
 
     @classmethod
     def init(cls,
diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py
new file mode 100644
index 0000000000..d32d3eb4f3
--- /dev/null
+++ b/api/core/workflow/graph_engine/entities/graph_init_params.py
@@ -0,0 +1,17 @@
+from pydantic import BaseModel, Field
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
+from models.workflow import WorkflowType
+
+
+class GraphInitParams(BaseModel):
+    # init params
+    tenant_id: str = Field(..., description="tenant / workspace id")
+    app_id: str = Field(..., description="app id")
+    workflow_type: WorkflowType = Field(..., description="workflow type")
+    workflow_id: str = Field(..., description="workflow id")
+    user_id: str = Field(..., description="user id")
+    user_from: UserFrom = Field(..., description="user from, account or end-user")
+    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
+    call_depth: int = Field(..., description="call depth")
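Downstream consumers are expected to dispatch on these event types with `isinstance`. A sketch of such a consumer, assuming an already-constructed engine (the `engine` argument is hypothetical; anything yielding `GraphEngineEvent`s works):

```python
from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunSucceededEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
)

def consume(engine) -> None:
    # engine.run() yields GraphEngineEvent subclasses in execution order.
    for event in engine.run():
        if isinstance(event, NodeRunStartedEvent):
            print(f"node {event.node_id} started (parallel={event.parallel_id})")
        elif isinstance(event, NodeRunStreamChunkEvent):
            print(event.chunk_content, end="")
        elif isinstance(event, GraphRunSucceededEvent):
            print("graph succeeded")
        elif isinstance(event, GraphRunFailedEvent):
            print(f"graph failed: {event.reason}")
```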
description="app id") + workflow_type: WorkflowType = Field(..., description="workflow type") + workflow_id: str = Field(..., description="workflow id") + user_id: str = Field(..., description="user id") + user_from: UserFrom = Field(..., description="user from, account or end-user") + invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index 1eeff95d13..4d17d277e3 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -1,25 +1,14 @@ - from pydantic import BaseModel, Field -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState -from core.workflow.nodes.base_node import UserFrom class GraphRuntimeState(BaseModel): - # init params - tenant_id: str - app_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - call_depth: int + variable_pool: VariablePool = Field(..., description="variable pool") - variable_pool: VariablePool + start_at: float = Field(..., description="start time") + total_tokens: int = Field(0, description="total tokens") + node_run_steps: int = Field(0, description="node run steps") - start_at: float - total_tokens: int = 0 - node_run_steps: int = 0 - - node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState) + node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 76a7080a87..abda39b6e3 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -2,7 +2,7 @@ import logging import queue import time from collections.abc import Generator -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from typing import Optional, cast from flask import Flask, current_app @@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, +) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.base_node import UserFrom, node_classes from extensions.ext_database import db +from models.workflow import WorkflowType thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun") logger = logging.getLogger(__name__) @@ -25,6 +35,8 @@ logger = logging.getLogger(__name__) class GraphEngine: def __init__(self, tenant_id: str, app_id: str, + workflow_type: WorkflowType, + 
workflow_id: str, user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, @@ -33,13 +45,18 @@ class GraphEngine: variable_pool: VariablePool, callbacks: list[BaseWorkflowCallback]) -> None: self.graph = graph - self.graph_runtime_state = GraphRuntimeState( + self.init_params = GraphInitParams( tenant_id=tenant_id, app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, user_id=user_id, user_from=user_from, invoke_from=invoke_from, - call_depth=call_depth, + call_depth=call_depth + ) + + self.graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, start_at=time.perf_counter() ) @@ -55,31 +72,31 @@ class GraphEngine: # TODO convert generator to result pass - def run(self) -> Generator: - # TODO trigger graph run start event + def run(self) -> Generator[GraphEngineEvent, None, None]: + # trigger graph run start event + yield GraphRunStartedEvent() try: - # TODO run graph - rst = self._run(start_node_id=self.graph.root_node_id) - except GraphRunFailedError as e: - # TODO self._graph_run_failed( - # error=e.error, - # callbacks=callbacks - # ) - pass + # run graph + generator = self._run(start_node_id=self.graph.root_node_id) + for item in generator: + yield item + if isinstance(item, NodeRunFailedEvent): + yield GraphRunFailedEvent(reason=item.reason) + return + + # trigger graph run success event + yield GraphRunSucceededEvent() + except (GraphRunFailedError, NodeRunFailedError) as e: + yield GraphRunFailedEvent(reason=e.error) + return except Exception as e: - # TODO self._workflow_run_failed( - # error=str(e), - # callbacks=callbacks - # ) - pass + yield GraphRunFailedEvent(reason=str(e)) + return - # TODO trigger graph run success event - - yield rst - - def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None): + def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: next_node_id = start_node_id + previous_node_id = None while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: @@ -92,10 +109,18 @@ class GraphEngine: ): raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) - # run node TODO generator - yield from self._run_node(node_id=next_node_id) + try: + # run node + yield from self._run_node( + node_id=next_node_id, + previous_node_id=previous_node_id, + parallel_id=in_parallel_id + ) + except Exception as e: + yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e)) + return - # todo if failed, break + previous_node_id = next_node_id # get next node ids edge_mappings = self.graph.edge_mapping.get(next_node_id) @@ -135,11 +160,11 @@ class GraphEngine: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id) if not parallel_id: - raise GraphRunFailedError('Node related parallel not found.') + raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.') parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: - raise GraphRunFailedError('Parallel not found.') + raise GraphRunFailedError(f'Parallel {parallel_id} not found.') # run parallel nodes, run in new thread and use queue to get results q: queue.Queue = queue.Queue() @@ -149,8 +174,9 @@ class GraphEngine: for edge in edge_mappings: futures.append(thread_pool.submit( self._run_parallel_node, - flask_app=current_app._get_current_object(), - parallel_start_node_id=edge.source_node_id, + 
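The failure contract in `run()` is worth spelling out: `_run` yields events as they happen, and `run()` watches the stream, converting the first `NodeRunFailedEvent` it sees into a terminal `GraphRunFailedEvent`. A self-contained toy of the same pattern (names are illustrative, not from this patch):

```python
from collections.abc import Generator

class Failed:
    def __init__(self, reason: str) -> None:
        self.reason = reason

def inner() -> Generator[object, None, None]:
    yield "step-1"
    yield Failed("boom")  # a failure surfaces as an event, not an exception

def outer() -> Generator[object, None, None]:
    for item in inner():
        yield item  # re-yield everything so observers see the full stream
        if isinstance(item, Failed):
            yield "graph failed: " + item.reason  # terminal failure event
            return
    yield "graph succeeded"  # only reached when no failure was seen

for event in outer():
    print(event)
```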
@@ -135,11 +160,11 @@ class GraphEngine:
                 # if nodes has no run conditions, parallel run all nodes
                 parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
                 if not parallel_id:
-                    raise GraphRunFailedError('Node related parallel not found.')
+                    raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
 
                 parallel = self.graph.parallel_mapping.get(parallel_id)
                 if not parallel:
-                    raise GraphRunFailedError('Parallel not found.')
+                    raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
 
                 # run parallel nodes, run in new thread and use queue to get results
                 q: queue.Queue = queue.Queue()
@@ -149,8 +174,9 @@ class GraphEngine:
                 for edge in edge_mappings:
                     futures.append(thread_pool.submit(
                         self._run_parallel_node,
-                        flask_app=current_app._get_current_object(),
-                        parallel_start_node_id=edge.source_node_id,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=edge.source_node_id,  # source_node_id is the start node of a parallel branch
                         q=q
                     ))
 
@@ -165,8 +191,9 @@ class GraphEngine:
                     except queue.Empty:
                         continue
 
-                for future in as_completed(futures):
-                    future.result()
+                # not necessary
+                # for future in as_completed(futures):
+                #     future.result()
 
                 # get final node id
                 final_node_id = parallel.end_to_node_id
@@ -178,48 +205,61 @@ class GraphEngine:
             if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
                 break
 
-    def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
+    def _run_parallel_node(self,
+                           flask_app: Flask,
+                           parallel_id: str,
+                           parallel_start_node_id: str,
+                           q: queue.Queue) -> None:
         """
         Run parallel nodes
         """
         with flask_app.app_context():
             try:
-                in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
-                if not in_parallel_id:
-                    q.put(None)
-                    return
-
-                # run node TODO generator
-                rst = self._run(
+                # run node
+                generator = self._run(
                     start_node_id=parallel_start_node_id,
-                    in_parallel_id=in_parallel_id
+                    in_parallel_id=parallel_id
                 )
 
-                if not rst:
-                    q.put(None)
-                    return
-
-                for item in rst:
-                    q.put(item)
-
-                q.put(None)
+                if generator:
+                    for item in generator:
+                        q.put(item)
             except Exception:
                 logger.exception("Unknown Error when generating in parallel")
             finally:
+                q.put(None)
                 db.session.remove()
 
-    def _run_node(self, node_id: str) -> Generator:
+    def _run_node(self,
+                  node_id: str,
+                  previous_node_id: Optional[str] = None,
+                  parallel_id: Optional[str] = None
+                  ) -> Generator[GraphEngineEvent, None, None]:
         """
         Run node
         """
         # get node config
         node_config = self.graph.node_id_config_mapping.get(node_id)
         if not node_config:
-            raise GraphRunFailedError('Node not found.')
+            raise GraphRunFailedError(f'Node {node_id} config not found.')
 
-        # todo convert to specific node
+        # convert to specific node
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        if not node_cls:
+            raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
 
-        # todo trigger node run start event
+        # init workflow run state
+        node_instance = node_cls(  # type: ignore
+            config=node_config,
+            graph_init_params=self.init_params,
+            graph=self.graph,
+            graph_runtime_state=self.graph_runtime_state,
+            previous_node_id=previous_node_id
+        )
+
+        # trigger node run start event
+        yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
 
         db.session.close()
 
         try:
             # run node
-            rst = node.run(
-                graph_runtime_state=self.graph_runtime_state,
-                graph=self.graph,
-                callbacks=self.callbacks
-            )
+            generator = node_instance.run()
 
-            yield from rst
+            yield from generator
 
             # todo record state
+
+            # trigger node run success event
+            yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
         except GenerateTaskStoppedException as e:
-            # TODO yield failed
-            # todo trigger node run failed event
-            pass
+            # trigger node run failed event
+            yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
+            return
         except Exception as e:
-            # logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
-            # TODO yield failed
-            # todo trigger node run failed event
-            pass
-
-        # todo trigger node run success event
-
-        db.session.close()
+            # todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
+            # trigger node run failed event
+            yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
+            return
+        finally:
+            db.session.close()
 
     def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
         """
@@ -265,3 +302,8 @@ class GraphEngine:
 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
         self.error = error
+
+
+class NodeRunFailedError(Exception):
+    def __init__(self, error: str):
+        self.error = error
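The parallel-branch plumbing above is a classic producer/consumer setup: each branch thread streams its events into a shared `queue.Queue` and pushes a `None` sentinel when it finishes (note that `q.put(None)` now lives in `finally`, so a crashing branch still releases the drain loop). A stripped-down, runnable sketch of the same pattern, independent of this codebase:

```python
import queue
import threading

def branch(name: str, q: queue.Queue) -> None:
    try:
        for i in range(3):
            q.put(f"{name}: event {i}")  # stream events to the consumer
    finally:
        q.put(None)                      # sentinel: this branch is done

q: queue.Queue = queue.Queue()
branches = ["a", "b"]
for name in branches:
    threading.Thread(target=branch, args=(name, q)).start()

finished = 0
while finished < len(branches):          # drain until every sentinel arrives
    item = q.get()
    if item is None:
        finished += 1
    else:
        print(item)
```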
config.get("id") - if not self.node_id: + node_id = config.get("id") + if not node_id: raise ValueError("Node ID is required.") + self.node_id = node_id self.node_data = self._node_data_cls(**config.get("data", {})) - self.callbacks = callbacks or [] @abstractmethod - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) \ + -> NodeRunResult | Generator[RunEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ raise NotImplementedError - def run(self, variable_pool: VariablePool) -> NodeRunResult: + def run(self) -> Generator[RunEvent, None, None]: """ Run node entry - :param variable_pool: variable pool :return: """ - result = self._run( - variable_pool=variable_pool - ) + result = self._run() - self.node_run_result = result - return result + if isinstance(result, NodeRunResult): + yield RunCompletedEvent( + run_result=result + ) + else: + yield from result def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: """ @@ -102,13 +88,13 @@ class BaseNode(ABC): :param value_selector: value selector :return: """ + # TODO remove callbacks if self.callbacks: for callback in self.callbacks: callback.on_node_text_chunk( node_id=self.node_id, text=text, metadata={ - "node_type": self.node_type, "value_selector": value_selector } ) diff --git a/api/core/workflow/nodes/event.py b/api/core/workflow/nodes/event.py new file mode 100644 index 0000000000..276c13a6d4 --- /dev/null +++ b/api/core/workflow/nodes/event.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index af928517d9..9b498062cc 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -3,15 +3,16 @@ from collections.abc import Generator from copy import deepcopy from typing import Optional, cast +from pydantic import BaseModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, @@ -23,9 +24,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import 
CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, @@ -43,13 +47,13 @@ class LLMNode(BaseNode): _node_data_cls = LLMNodeData node_type = NodeType.LLM - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> Generator[RunEvent, None, None]: """ Run node - :param variable_pool: variable pool :return: """ node_data = cast(LLMNodeData, deepcopy(self.node_data)) + variable_pool = self.graph_runtime_state.variable_pool node_inputs = None process_data = None @@ -76,10 +80,17 @@ class LLMNode(BaseNode): node_inputs['#files#'] = [file.to_dict() for file in files] # fetch context value - context = self._fetch_context(node_data, variable_pool) + generator = self._fetch_context(node_data, variable_pool) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield NodeRunRetrieverResourceEvent( + retriever_resources=event.retriever_resources + ) if context: - node_inputs['#context#'] = context + node_inputs['#context#'] = context # type: ignore # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) @@ -90,7 +101,7 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) + query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) # type: ignore if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -109,41 +120,57 @@ class LLMNode(BaseNode): } # handle invoke result - result_text, usage = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop ) + + result_text = '' + usage = LLMUsage.empty_usage() + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompleted): + result_text = event.text + usage = event.usage + break except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data + ) ) + return outputs = { 'text': result_text, 'usage': jsonable_encoder(usage) } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + yield RunCompletedEvent( + run_result=NodeRunResult( + 
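The `run()`/`_run()` contract above is the key API change for node authors: `_run` may either return a single `NodeRunResult` or be a generator yielding `RunEvent`s, and `run()` normalizes both into an event stream. A standalone toy of that normalization, with simplified stand-in types rather than the real classes:

```python
from collections.abc import Generator

class Result:                      # stand-in for NodeRunResult
    def __init__(self, outputs: dict) -> None:
        self.outputs = outputs

class Completed:                   # stand-in for RunCompletedEvent
    def __init__(self, run_result: Result) -> None:
        self.run_result = run_result

def run(_run) -> Generator[object, None, None]:
    # Mirrors BaseNode.run(): a plain result is wrapped into a completion
    # event; a generator is passed through unchanged.
    result = _run()
    if isinstance(result, Result):
        yield Completed(result)
    else:
        yield from result

def simple_node():                 # non-streaming node: returns a result
    return Result({"text": "hello"})

def streaming_node():              # streaming node: yields events itself
    yield "chunk-1"
    yield Completed(Result({"text": "chunk-1"}))

print([type(e).__name__ for e in run(simple_node)])     # ['Completed']
print([type(e).__name__ for e in run(streaming_node)])  # ['str', 'Completed']
```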
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index af928517d9..9b498062cc 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -3,15 +3,16 @@ from collections.abc import Generator
 from copy import deepcopy
 from typing import Optional, cast
 
+from pydantic import BaseModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,
@@ -23,9 +24,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -43,13 +47,13 @@ class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     node_type = NodeType.LLM
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool
 
         node_inputs = None
         process_data = None
@@ -76,10 +80,17 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]
 
             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield NodeRunRetrieverResourceEvent(
+                        node_id=self.node_id, retriever_resources=event.retriever_resources
+                    )
 
             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore
 
             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -90,7 +101,7 @@ class LLMNode(BaseNode):
             # fetch prompt messages
             prompt_messages, stop = self._fetch_prompt_messages(
                 node_data=node_data,
-                query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
+                query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])  # type: ignore
                 if node_data.memory else None,
                 query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
                 inputs=inputs,
@@ -109,41 +120,57 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
             )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    break
         except Exception as e:
-            return NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-                inputs=node_inputs,
-                process_data=process_data
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    error=str(e),
+                    inputs=node_inputs,
+                    process_data=process_data
+                )
             )
+            return
 
         outputs = {
             'text': result_text,
             'usage': jsonable_encoder(usage)
         }
 
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=node_inputs,
-            process_data=process_data,
-            outputs=outputs,
-            metadata={
-                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
-                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
-                NodeRunMetadataKey.CURRENCY: usage.currency
-            }
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=node_inputs,
+                process_data=process_data,
+                outputs=outputs,
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
+                    NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
+                    NodeRunMetadataKey.CURRENCY: usage.currency
+                }
+            )
         )
 
     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model
@@ -163,30 +190,41 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
         for result in invoke_result:
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )
 
             if not model:
                 model = result.model
@@ -200,11 +238,14 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage
-
-    def _transform_chat_messages(self,
-        messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
-    ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage
+        )
+
+    def _transform_chat_messages(self,
+                                 messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+                                 ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
         """
         Transform chat messages
 
@@ -213,13 +254,13 @@ class LLMNode(BaseNode):
         """
 
         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == 'jinja2':
+            if messages.edition_type == 'jinja2' and messages.jinja2_text:
                 messages.text = messages.jinja2_text
 
             return messages
 
         for message in messages:
-            if message.edition_type == 'jinja2':
+            if message.edition_type == 'jinja2' and message.jinja2_text:
                 message.text = message.jinja2_text
 
         return messages
@@ -249,13 +290,13 @@ class LLMNode(BaseNode):
                 # check if it's a context structure
                 if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
                     return d['content']
-
+
                 # else, parse the dict
                 try:
                     return json.dumps(d, ensure_ascii=False)
                 except Exception:
                     return str(d)
-
+
             if isinstance(value, str):
                 value = value
             elif isinstance(value, list):
@@ -319,7 +360,7 @@ class LLMNode(BaseNode):
 
             inputs[variable_selector.variable] = variable_value
 
-        return inputs
+        return inputs  # type: ignore
 
     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
         """
@@ -337,7 +378,7 @@ class LLMNode(BaseNode):
 
         return files
 
-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data
@@ -353,7 +394,10 @@ class LLMNode(BaseNode):
         context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []
@@ -370,17 +414,10 @@ class LLMNode(BaseNode):
                     if retriever_resource:
                         original_retriever_resource.append(retriever_resource)
 
-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )
 
     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """
@@ -561,7 +598,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -633,13 +671,13 @@ class LLMNode(BaseNode):
             db.session.commit()
 
     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
-
+        node_data = cast(LLMNodeData, node_data)
         prompt_template = node_data.prompt_template
 
         variable_selectors = []
@@ -727,3 +765,11 @@ class LLMNode(BaseNode):
                 }
             }
         }
+
+
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
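Callers of `_invoke_llm` now consume a mixed stream: zero or more `RunStreamChunkEvent`s followed by a single `ModelInvokeCompleted`; the question classifier node below uses exactly this idiom. The accumulation pattern as a standalone toy, with stand-in classes rather than the real entities:

```python
from collections.abc import Generator

class Chunk:                 # stand-in for RunStreamChunkEvent
    def __init__(self, content: str) -> None:
        self.content = content

class Done:                  # stand-in for ModelInvokeCompleted
    def __init__(self, text: str, tokens: int) -> None:
        self.text, self.tokens = text, tokens

def invoke() -> Generator[Chunk | Done, None, None]:
    for part in ("Hel", "lo"):
        yield Chunk(part)            # streamed as produced
    yield Done("Hello", tokens=2)    # terminal summary event

result_text, total_tokens = "", 0
for event in invoke():
    if isinstance(event, Chunk):
        print(event.content, end="")  # forward chunks downstream
    elif isinstance(event, Done):
        result_text, total_tokens = event.text, event.tokens
        break                         # the summary is always last
print(f"\n{result_text=} {total_tokens=}")
```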
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 76f3dec836..fda679adc1 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -16,7 +17,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
     _node_data_cls = QuestionClassifierNodeData
     node_type = NodeType.QUESTION_CLASSIFIER
 
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
         node_data = cast(QuestionClassifierNodeData, node_data)
+        variable_pool = self.graph_runtime_state.variable_pool
 
         # extract variables
         query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
@@ -62,12 +64,21 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        generator = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
         )
+
+        result_text = ''
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            if isinstance(event, ModelInvokeCompleted):
+                result_text = event.text
+                usage = event.usage
+                break
+
         category_name = node_data.classes[0].name
         category_id = node_data.classes[0].id
         try:
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 840b12878c..5a54016aab 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -17,47 +17,17 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
 from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.graph_engine import GraphEngine
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
-from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.end.end_node import EndNode
-from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
-from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.entities.node_entities import node_classes
+from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
 from core.workflow.nodes.iteration.entities import IterationState
-from core.workflow.nodes.iteration.iteration_node import IterationNode
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
 from core.workflow.nodes.llm.entities import LLMNodeData
-from core.workflow.nodes.llm.llm_node import LLMNode
-from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
 from core.workflow.nodes.start.start_node import StartNode
-from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
-from core.workflow.nodes.tool.tool_node import ToolNode
-from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
 from extensions.ext_database import db
 from models.workflow import (
     Workflow,
     WorkflowNodeExecutionStatus,
+    WorkflowType,
 )
 
-node_classes = {
-    NodeType.START: StartNode,
-    NodeType.END: EndNode,
-    NodeType.ANSWER: AnswerNode,
-    NodeType.LLM: LLMNode,
-    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
-    NodeType.IF_ELSE: IfElseNode,
-    NodeType.CODE: CodeNode,
-    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
-    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
-    NodeType.HTTP_REQUEST: HttpRequestNode,
-    NodeType.TOOL: ToolNode,
-    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
-    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
-    NodeType.ITERATION: IterationNode,
-    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
-}
-
 logger = logging.getLogger(__name__)
 
 
@@ -115,6 +85,8 @@ class WorkflowEntry:
         graph_engine = GraphEngine(
             tenant_id=workflow.tenant_id,
             app_id=workflow.app_id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            workflow_id=workflow.id,
             user_id=user_id,
             user_from=user_from,
             invoke_from=invoke_from,
@@ -692,7 +664,7 @@ class WorkflowEntry:
         # add steps
         workflow_run_state.workflow_node_steps += 1
 
-    def _workflow_iteration_next(self, graph: dict,
+    def _workflow_iteration_next(self, graph: dict,
                                  current_iteration_node: BaseIterationNode,
                                  workflow_run_state: WorkflowRunState,
                                  callbacks: list[BaseWorkflowCallback] = None) -> None:
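With these pieces in place, the engine's external contract can be checked end to end: the stream always opens with `GraphRunStartedEvent` and always terminates with either `GraphRunSucceededEvent` or `GraphRunFailedEvent`. A test-style sketch of that invariant; how the engine is built is left abstract, since any configured `GraphEngine` works:

```python
from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)

def assert_event_envelope(engine) -> None:
    # Collect the full event stream for a single run.
    events = list(engine.run())

    # run() yields the start event before doing any work...
    assert isinstance(events[0], GraphRunStartedEvent)
    # ...and always closes with exactly one outcome event.
    assert isinstance(events[-1], (GraphRunSucceededEvent, GraphRunFailedEvent))
```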