graph engine implement

takatost 2024-07-15 23:40:02 +08:00
parent 821e09b259
commit 00fb23d0c9
11 changed files with 511 additions and 270 deletions

View File

@@ -3,6 +3,20 @@ from typing import Any, Optional
 from pydantic import BaseModel
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.end.end_node import EndNode
+from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
+from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.iteration.iteration_node import IterationNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
 from models.workflow import WorkflowNodeExecutionStatus

@@ -41,6 +55,25 @@ class NodeType(Enum):
         raise ValueError(f'invalid node type value {value}')
+
+
+node_classes = {
+    NodeType.START: StartNode,
+    NodeType.END: EndNode,
+    NodeType.ANSWER: AnswerNode,
+    NodeType.LLM: LLMNode,
+    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
+    NodeType.IF_ELSE: IfElseNode,
+    NodeType.CODE: CodeNode,
+    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
+    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
+    NodeType.HTTP_REQUEST: HttpRequestNode,
+    NodeType.TOOL: ToolNode,
+    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
+    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
+    NodeType.ITERATION: IterationNode,
+    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
+}
+
 class SystemVariable(Enum):
     """
     System Variables.

@@ -90,3 +123,23 @@ class NodeRunResult(BaseModel):
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
     error: Optional[str] = None  # error message if status is failed
+
+
+class UserFrom(Enum):
+    """
+    User from
+    """
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+    @classmethod
+    def value_of(cls, value: str) -> "UserFrom":
+        """
+        Value of
+        :param value: value
+        :return:
+        """
+        for item in cls:
+            if item.value == value:
+                return item
+
+        raise ValueError(f"Invalid value: {value}")

View File

@@ -0,0 +1,116 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+from core.workflow.entities.node_entities import NodeRunResult
+
+
+class GraphEngineEvent(BaseModel):
+    pass
+
+
+###########################################
+# Graph Events
+###########################################
+
+
+class BaseGraphEvent(GraphEngineEvent):
+    pass
+
+
+class GraphRunStartedEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunBackToRootEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunSucceededEvent(BaseGraphEvent):
+    pass
+
+
+class GraphRunFailedEvent(BaseGraphEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+###########################################
+# Node Events
+###########################################
+
+
+class BaseNodeEvent(GraphEngineEvent):
+    node_id: str = Field(..., description="node id")
+    parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
+    # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
+
+
+class NodeRunStartedEvent(BaseNodeEvent):
+    pass
+
+
+class NodeRunStreamChunkEvent(BaseNodeEvent):
+    chunk_content: str = Field(..., description="chunk content")
+
+
+class NodeRunRetrieverResourceEvent(BaseNodeEvent):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+
+
+class NodeRunSucceededEvent(BaseNodeEvent):
+    run_result: NodeRunResult = Field(..., description="run result")
+
+
+class NodeRunFailedEvent(BaseNodeEvent):
+    run_result: NodeRunResult = Field(..., description="run result")
+    reason: str = Field("", description="failed reason")
+
+    @model_validator(mode='before')
+    def init_reason(cls, values: dict) -> dict:
+        if not values.get("reason"):
+            values["reason"] = values.get("run_result").error or "Unknown error"
+        return values
+
+
+###########################################
+# Parallel Events
+###########################################
+
+
+class BaseParallelEvent(GraphEngineEvent):
+    parallel_id: str = Field(..., description="parallel id")
+
+
+class ParallelRunStartedEvent(BaseParallelEvent):
+    pass
+
+
+class ParallelRunSucceededEvent(BaseParallelEvent):
+    pass
+
+
+class ParallelRunFailedEvent(BaseParallelEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+###########################################
+# Iteration Events
+###########################################
+
+
+class BaseIterationEvent(GraphEngineEvent):
+    iteration_id: str = Field(..., description="iteration id")
+
+
+class IterationRunStartedEvent(BaseIterationEvent):
+    pass
+
+
+class IterationRunSucceededEvent(BaseIterationEvent):
+    pass
+
+
+class IterationRunFailedEvent(BaseIterationEvent):
+    reason: str = Field(..., description="failed reason")
+
+
+InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
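
Every event above inherits from GraphEngineEvent, so a consumer can iterate the stream produced by GraphEngine.run() (reworked later in this commit) and branch on type. A consumption sketch, assuming the import path introduced by this file:

    from core.workflow.graph_engine.entities.event import (
        GraphRunFailedEvent,
        GraphRunStartedEvent,
        GraphRunSucceededEvent,
        NodeRunStartedEvent,
    )

    def print_events(events):
        # events: any iterable of GraphEngineEvent, e.g. GraphEngine.run()
        for event in events:
            if isinstance(event, GraphRunStartedEvent):
                print("graph run started")
            elif isinstance(event, NodeRunStartedEvent):
                print(f"node {event.node_id} started (parallel={event.parallel_id})")
            elif isinstance(event, GraphRunSucceededEvent):
                print("graph run succeeded")
            elif isinstance(event, GraphRunFailedEvent):
                print(f"graph run failed: {event.reason}")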

View File

@@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition

 class GraphEdge(BaseModel):
-    source_node_id: str
-    """source node id"""
-
-    target_node_id: str
-    """target node id"""
-
-    run_condition: Optional[RunCondition] = None
-    """condition to run the edge"""
+    source_node_id: str = Field(..., description="source node id")
+    target_node_id: str = Field(..., description="target node id")
+    run_condition: Optional[RunCondition] = Field(None, description="run condition")


 class GraphParallel(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    """random uuid parallel id"""
-
-    start_from_node_id: str
-    """start from node id"""
-
-    end_to_node_id: Optional[str] = None
-    """end to node id"""
-
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if exists"""
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
+    start_from_node_id: str = Field(..., description="start from node id")
+    parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
+    end_to_node_id: Optional[str] = Field(None, description="end to node id")


 class Graph(BaseModel):
-    root_node_id: str
-    """root node id of the graph"""
-
-    node_ids: list[str] = Field(default_factory=list)
-    """graph node ids"""
-
-    node_id_config_mapping: dict[str, dict] = Field(default_factory=dict)
-    """node configs mapping (node id: node config)"""
-
-    edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
-    """graph edge mapping (source node id: edges)"""
-
-    parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
-    """graph parallel mapping (parallel id: parallel)"""
-
-    node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
-    """graph node parallel mapping (node id: parallel id)"""
+    root_node_id: str = Field(..., description="root node id of the graph")
+    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
+    node_id_config_mapping: dict[str, dict] = Field(
+        default_factory=dict,
+        description="node configs mapping (node id: node config)"
+    )
+    edge_mapping: dict[str, list[GraphEdge]] = Field(
+        default_factory=dict,
+        description="graph edge mapping (source node id: edges)"
+    )
+    parallel_mapping: dict[str, GraphParallel] = Field(
+        default_factory=dict,
+        description="graph parallel mapping (parallel id: parallel)"
+    )
+    node_parallel_mapping: dict[str, str] = Field(
+        default_factory=dict,
+        description="graph node parallel mapping (node id: parallel id)"
+    )

     @classmethod
     def init(cls,
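
With every field now a described pydantic Field, a Graph can also be declared directly. A hand-built two-node graph for illustration (real instances come from Graph.init(), whose diff is truncated here; the node configs below are hypothetical):

    from core.workflow.graph_engine.entities.graph import Graph, GraphEdge

    graph = Graph(
        root_node_id="start",
        node_ids=["start", "llm"],
        node_id_config_mapping={
            "start": {"id": "start", "data": {"type": "start"}},
            "llm": {"id": "llm", "data": {"type": "llm"}},
        },
        edge_mapping={
            # one unconditional edge from the start node to the llm node
            "start": [GraphEdge(source_node_id="start", target_node_id="llm")],
        },
    )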

View File

@@ -0,0 +1,17 @@
+from pydantic import BaseModel, Field
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import UserFrom
+from models.workflow import WorkflowType
+
+
+class GraphInitParams(BaseModel):
+    # init params
+    tenant_id: str = Field(..., description="tenant / workspace id")
+    app_id: str = Field(..., description="app id")
+    workflow_type: WorkflowType = Field(..., description="workflow type")
+    workflow_id: str = Field(..., description="workflow id")
+    user_id: str = Field(..., description="user id")
+    user_from: UserFrom = Field(..., description="user from, account or end-user")
+    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
+    call_depth: int = Field(..., description="call depth")

View File

@@ -1,25 +1,14 @@
 from pydantic import BaseModel, Field

-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
-from core.workflow.nodes.base_node import UserFrom


 class GraphRuntimeState(BaseModel):
-    # init params
-    tenant_id: str
-    app_id: str
-    user_id: str
-    user_from: UserFrom
-    invoke_from: InvokeFrom
-    call_depth: int
-
-    variable_pool: VariablePool
-
-    start_at: float
-    total_tokens: int = 0
-    node_run_steps: int = 0
-    node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
+    variable_pool: VariablePool = Field(..., description="variable pool")
+
+    start_at: float = Field(..., description="start time")
+    total_tokens: int = Field(0, description="total tokens")
+    node_run_steps: int = Field(0, description="node run steps")
+    node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")

View File

@@ -2,7 +2,7 @@ import logging
 import queue
 import time
 from collections.abc import Generator
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
 from typing import Optional, cast

 from flask import Flask, current_app

@@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    GraphRunFailedEvent,
+    GraphRunStartedEvent,
+    GraphRunSucceededEvent,
+    NodeRunFailedEvent,
+    NodeRunStartedEvent,
+    NodeRunSucceededEvent,
+)
 from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.nodes.base_node import UserFrom, node_classes
 from extensions.ext_database import db
+from models.workflow import WorkflowType

 thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")

 logger = logging.getLogger(__name__)

@@ -25,6 +35,8 @@ logger = logging.getLogger(__name__)
 class GraphEngine:
     def __init__(self, tenant_id: str,
                  app_id: str,
+                 workflow_type: WorkflowType,
+                 workflow_id: str,
                  user_id: str,
                  user_from: UserFrom,
                  invoke_from: InvokeFrom,

@@ -33,13 +45,18 @@ class GraphEngine:
                  variable_pool: VariablePool,
                  callbacks: list[BaseWorkflowCallback]) -> None:
         self.graph = graph
-        self.graph_runtime_state = GraphRuntimeState(
+        self.init_params = GraphInitParams(
             tenant_id=tenant_id,
             app_id=app_id,
+            workflow_type=workflow_type,
+            workflow_id=workflow_id,
             user_id=user_id,
             user_from=user_from,
             invoke_from=invoke_from,
-            call_depth=call_depth,
+            call_depth=call_depth
+        )
+
+        self.graph_runtime_state = GraphRuntimeState(
             variable_pool=variable_pool,
             start_at=time.perf_counter()
         )

@@ -55,31 +72,31 @@ class GraphEngine:
         # TODO convert generator to result
         pass

-    def run(self) -> Generator:
-        # TODO trigger graph run start event
+    def run(self) -> Generator[GraphEngineEvent, None, None]:
+        # trigger graph run start event
+        yield GraphRunStartedEvent()

         try:
-            # TODO run graph
-            rst = self._run(start_node_id=self.graph.root_node_id)
-        except GraphRunFailedError as e:
-            # TODO self._graph_run_failed(
-            #     error=e.error,
-            #     callbacks=callbacks
-            # )
-            pass
-        except Exception as e:
-            # TODO self._workflow_run_failed(
-            #     error=str(e),
-            #     callbacks=callbacks
-            # )
-            pass
-
-        # TODO trigger graph run success event
-        yield rst
+            # run graph
+            generator = self._run(start_node_id=self.graph.root_node_id)
+            for item in generator:
+                yield item
+                if isinstance(item, NodeRunFailedEvent):
+                    yield GraphRunFailedEvent(reason=item.reason)
+                    return
+
+            # trigger graph run success event
+            yield GraphRunSucceededEvent()
+        except (GraphRunFailedError, NodeRunFailedError) as e:
+            yield GraphRunFailedEvent(reason=e.error)
+            return
+        except Exception as e:
+            yield GraphRunFailedEvent(reason=str(e))
+            return

-    def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
+    def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
         next_node_id = start_node_id
+        previous_node_id = None
         while True:
             # max steps reached
             if self.graph_runtime_state.node_run_steps > self.max_execution_steps:

@@ -92,10 +109,18 @@ class GraphEngine:
             ):
                 raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

-            # run node TODO generator
-            yield from self._run_node(node_id=next_node_id)
-
-            # todo if failed, break
+            try:
+                # run node
+                yield from self._run_node(
+                    node_id=next_node_id,
+                    previous_node_id=previous_node_id,
+                    parallel_id=in_parallel_id
+                )
+            except Exception as e:
+                yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
+                return
+
+            previous_node_id = next_node_id

             # get next node ids
             edge_mappings = self.graph.edge_mapping.get(next_node_id)

@@ -135,11 +160,11 @@ class GraphEngine:
                     # if nodes has no run conditions, parallel run all nodes
                     parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
                     if not parallel_id:
-                        raise GraphRunFailedError('Node related parallel not found.')
+                        raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')

                     parallel = self.graph.parallel_mapping.get(parallel_id)
                     if not parallel:
-                        raise GraphRunFailedError('Parallel not found.')
+                        raise GraphRunFailedError(f'Parallel {parallel_id} not found.')

                     # run parallel nodes, run in new thread and use queue to get results
                     q: queue.Queue = queue.Queue()

@@ -149,8 +174,9 @@ class GraphEngine:
                     for edge in edge_mappings:
                         futures.append(thread_pool.submit(
                             self._run_parallel_node,
-                            flask_app=current_app._get_current_object(),
-                            parallel_start_node_id=edge.source_node_id,
+                            flask_app=current_app._get_current_object(),  # type: ignore
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=edge.source_node_id,  # source_node_id is start nodes in parallel
                             q=q
                         ))

@@ -165,8 +191,9 @@ class GraphEngine:
                         except queue.Empty:
                             continue

-                    for future in as_completed(futures):
-                        future.result()
+                    # not necessary
+                    # for future in as_completed(futures):
+                    #     future.result()

                     # get final node id
                     final_node_id = parallel.end_to_node_id

@@ -178,48 +205,61 @@ class GraphEngine:
             if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
                 break

-    def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
+    def _run_parallel_node(self,
+                           flask_app: Flask,
+                           parallel_id: str,
+                           parallel_start_node_id: str,
+                           q: queue.Queue) -> None:
         """
         Run parallel nodes
         """
         with flask_app.app_context():
             try:
-                in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
-                if not in_parallel_id:
-                    q.put(None)
-                    return
-
-                # run node TODO generator
-                rst = self._run(
+                # run node
+                generator = self._run(
                     start_node_id=parallel_start_node_id,
-                    in_parallel_id=in_parallel_id
+                    in_parallel_id=parallel_id
                 )

-                if not rst:
-                    q.put(None)
-                    return
-
-                for item in rst:
-                    q.put(item)
-
-                q.put(None)
+                if generator:
+                    for item in generator:
+                        q.put(item)
             except Exception:
                 logger.exception("Unknown Error when generating in parallel")
             finally:
+                q.put(None)
                 db.session.remove()

-    def _run_node(self, node_id: str) -> Generator:
+    def _run_node(self,
+                  node_id: str,
+                  previous_node_id: Optional[str] = None,
+                  parallel_id: Optional[str] = None
+                  ) -> Generator[GraphEngineEvent, None, None]:
         """
         Run node
         """
         # get node config
         node_config = self.graph.node_id_config_mapping.get(node_id)
         if not node_config:
-            raise GraphRunFailedError('Node not found.')
-
-        # todo convert to specific node
-
-        # todo trigger node run start event
+            raise GraphRunFailedError(f'Node {node_id} config not found.')
+
+        # convert to specific node
+        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
+        node_cls = node_classes.get(node_type)
+        if not node_cls:
+            raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
+
+        # init workflow run state
+        node_instance = node_cls(  # type: ignore
+            config=node_config,
+            graph_init_params=self.init_params,
+            graph=self.graph,
+            graph_runtime_state=self.graph_runtime_state,
+            previous_node_id=previous_node_id
+        )
+
+        # trigger node run start event
+        yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)

         db.session.close()

@@ -229,27 +269,24 @@ class GraphEngine:
         try:
             # run node
-            rst = node.run(
-                graph_runtime_state=self.graph_runtime_state,
-                graph=self.graph,
-                callbacks=self.callbacks
-            )
-
-            yield from rst
+            generator = node_instance.run()
+
+            yield from generator

             # todo record state
+
+            # trigger node run success event
+            yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
         except GenerateTaskStoppedException as e:
-            # TODO yield failed
-            # todo trigger node run failed event
-            pass
+            # trigger node run failed event
+            yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
+            return
         except Exception as e:
-            # logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
-            # TODO yield failed
-            # todo trigger node run failed event
-            pass
+            # todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
+            # trigger node run failed event
+            yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
+            return
         finally:
-            # todo trigger node run success event
             db.session.close()

     def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:

@@ -265,3 +302,8 @@ class GraphEngine:
 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
         self.error = error
+
+
+class NodeRunFailedError(Exception):
+    def __init__(self, error: str):
+        self.error = error
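
The run() generator now front-loads GraphRunStartedEvent, converts any NodeRunFailedEvent or raised error into a terminal GraphRunFailedEvent, and closes with GraphRunSucceededEvent. A sketch of driving the reworked engine under the new constructor signature (names such as workflow, graph, variable_pool, user_id and logger are assumed to be prepared in the caller, as WorkflowEntry does in the last file below):

    engine = GraphEngine(
        tenant_id=workflow.tenant_id,
        app_id=workflow.app_id,
        workflow_type=WorkflowType.value_of(workflow.type),
        workflow_id=workflow.id,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
        graph=graph,
        variable_pool=variable_pool,
        callbacks=[],
    )

    for event in engine.run():
        if isinstance(event, GraphRunFailedEvent):
            logger.error("graph run failed: %s", event.reason)
            break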

View File

@@ -1,33 +1,17 @@
 from abc import ABC, abstractmethod
-from enum import Enum
+from collections.abc import Generator
 from typing import Optional

 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
 from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 from core.workflow.nodes.iterable_node import IterableNodeMixin
+from models.workflow import WorkflowType
-
-
-class UserFrom(Enum):
-    """
-    User from
-    """
-    ACCOUNT = "account"
-    END_USER = "end-user"
-
-    @classmethod
-    def value_of(cls, value: str) -> "UserFrom":
-        """
-        Value of
-        :param value: value
-        :return:
-        """
-        for item in cls:
-            if item.value == value:
-                return item
-
-        raise ValueError(f"Invalid value: {value}")

@@ -36,64 +20,66 @@ class BaseNode(ABC):
     tenant_id: str
     app_id: str
+    workflow_type: WorkflowType
     workflow_id: str
     user_id: str
     user_from: UserFrom
     invoke_from: InvokeFrom
     workflow_call_depth: int
+    graph: Graph
+    graph_runtime_state: GraphRuntimeState
+    previous_node_id: Optional[str] = None

     node_id: str
     node_data: BaseNodeData
-    node_run_result: Optional[NodeRunResult] = None
-    callbacks: list[BaseWorkflowCallback]

-    def __init__(self, tenant_id: str,
-                 app_id: str,
-                 workflow_id: str,
-                 user_id: str,
-                 user_from: UserFrom,
-                 invoke_from: InvokeFrom,
+    def __init__(self,
                  config: dict,
-                 callbacks: list[BaseWorkflowCallback] = None,
-                 workflow_call_depth: int = 0) -> None:
-        self.tenant_id = tenant_id
-        self.app_id = app_id
-        self.workflow_id = workflow_id
-        self.user_id = user_id
-        self.user_from = user_from
-        self.invoke_from = invoke_from
-        self.workflow_call_depth = workflow_call_depth
+                 graph_init_params: GraphInitParams,
+                 graph: Graph,
+                 graph_runtime_state: GraphRuntimeState,
+                 previous_node_id: Optional[str] = None) -> None:
+        self.tenant_id = graph_init_params.tenant_id
+        self.app_id = graph_init_params.app_id
+        self.workflow_type = graph_init_params.workflow_type
+        self.workflow_id = graph_init_params.workflow_id
+        self.user_id = graph_init_params.user_id
+        self.user_from = graph_init_params.user_from
+        self.invoke_from = graph_init_params.invoke_from
+        self.workflow_call_depth = graph_init_params.call_depth
+        self.graph = graph
+        self.graph_runtime_state = graph_runtime_state
+        self.previous_node_id = previous_node_id

-        self.node_id = config.get("id")
-        if not self.node_id:
+        node_id = config.get("id")
+        if not node_id:
             raise ValueError("Node ID is required.")
+
+        self.node_id = node_id

         self.node_data = self._node_data_cls(**config.get("data", {}))
-        self.callbacks = callbacks or []

     @abstractmethod
-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) \
+            -> NodeRunResult | Generator[RunEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         raise NotImplementedError

-    def run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def run(self) -> Generator[RunEvent, None, None]:
         """
         Run node entry
-        :param variable_pool: variable pool
         :return:
         """
-        result = self._run(
-            variable_pool=variable_pool
-        )
+        result = self._run()

-        self.node_run_result = result
-        return result
+        if isinstance(result, NodeRunResult):
+            yield RunCompletedEvent(
+                run_result=result
+            )
+        else:
+            yield from result

@@ -102,13 +88,13 @@ class BaseNode(ABC):
         :param value_selector: value selector
         :return:
         """
+        # TODO remove callbacks
         if self.callbacks:
             for callback in self.callbacks:
                 callback.on_node_text_chunk(
                     node_id=self.node_id,
                     text=text,
                     metadata={
-                        "node_type": self.node_type,
                         "value_selector": value_selector
                     }
                 )
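
Under the new contract a node no longer receives a variable pool argument: _run() reads state from self.graph_runtime_state and may return either a NodeRunResult or a generator of RunEvents, and run() normalizes both into an event stream. A hypothetical minimal subclass for illustration (the data class and node type here are placeholders, not part of this commit):

    from core.workflow.entities.base_node_data_entities import BaseNodeData
    from core.workflow.entities.node_entities import NodeRunResult, NodeType
    from core.workflow.nodes.base_node import BaseNode
    from models.workflow import WorkflowNodeExecutionStatus

    class EchoNodeData(BaseNodeData):
        pass

    class EchoNode(BaseNode):
        _node_data_cls = EchoNodeData
        node_type = NodeType.START  # placeholder type

        def _run(self) -> NodeRunResult:
            # inputs now come from the shared runtime state, not a parameter
            variable_pool = self.graph_runtime_state.variable_pool
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={},
                outputs={"echo": self.node_id},
            )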

View File

@@ -0,0 +1,20 @@
+from pydantic import BaseModel, Field
+
+from core.workflow.entities.node_entities import NodeRunResult
+
+
+class RunCompletedEvent(BaseModel):
+    run_result: NodeRunResult = Field(..., description="run result")
+
+
+class RunStreamChunkEvent(BaseModel):
+    chunk_content: str = Field(..., description="chunk content")
+    from_variable_selector: list[str] = Field(..., description="from variable selector")
+
+
+class RunRetrieverResourceEvent(BaseModel):
+    retriever_resources: list[dict] = Field(..., description="retriever resources")
+    context: str = Field(..., description="context")
+
+
+RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent

View File

@@ -3,15 +3,16 @@ from collections.abc import Generator
 from copy import deepcopy
 from typing import Optional, cast

+from pydantic import BaseModel
+
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.entities.model_entities import ModelStatus
 from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
     PromptMessage,

@@ -23,9 +24,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.llm.entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,

@@ -43,13 +47,13 @@ class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     node_type = NodeType.LLM

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> Generator[RunEvent, None, None]:
         """
         Run node
-        :param variable_pool: variable pool
         :return:
         """
         node_data = cast(LLMNodeData, deepcopy(self.node_data))
+        variable_pool = self.graph_runtime_state.variable_pool

         node_inputs = None
         process_data = None

@@ -76,10 +80,17 @@ class LLMNode(BaseNode):
                 node_inputs['#files#'] = [file.to_dict() for file in files]

             # fetch context value
-            context = self._fetch_context(node_data, variable_pool)
+            generator = self._fetch_context(node_data, variable_pool)
+            context = None
+            for event in generator:
+                if isinstance(event, RunRetrieverResourceEvent):
+                    context = event.context
+                    yield NodeRunRetrieverResourceEvent(
+                        retriever_resources=event.retriever_resources
+                    )

             if context:
-                node_inputs['#context#'] = context
+                node_inputs['#context#'] = context  # type: ignore

             # fetch model config
             model_instance, model_config = self._fetch_model_config(node_data.model)

@@ -90,7 +101,7 @@ class LLMNode(BaseNode):
             # fetch prompt messages
             prompt_messages, stop = self._fetch_prompt_messages(
                 node_data=node_data,
-                query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
+                query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])  # type: ignore
                 if node_data.memory else None,
                 query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
                 inputs=inputs,

@@ -109,26 +120,40 @@ class LLMNode(BaseNode):
             }

             # handle invoke result
-            result_text, usage = self._invoke_llm(
+            generator = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop
            )
+
+            result_text = ''
+            usage = LLMUsage.empty_usage()
+            for event in generator:
+                if isinstance(event, RunStreamChunkEvent):
+                    yield event
+                elif isinstance(event, ModelInvokeCompleted):
+                    result_text = event.text
+                    usage = event.usage
+                    break
         except Exception as e:
-            return NodeRunResult(
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     error=str(e),
                     inputs=node_inputs,
                     process_data=process_data
                 )
+            )
+            return

         outputs = {
             'text': result_text,
             'usage': jsonable_encoder(usage)
         }

-        return NodeRunResult(
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 inputs=node_inputs,
                 process_data=process_data,

@@ -139,11 +164,13 @@ class LLMNode(BaseNode):
                 NodeRunMetadataKey.CURRENCY: usage.currency
             }
         )
+        )

     def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
-                    stop: list[str]) -> tuple[str, LLMUsage]:
+                    stop: Optional[list[str]] = None) \
+            -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
         """
         Invoke large language model
         :param node_data_model: node data model

@@ -163,30 +190,41 @@ class LLMNode(BaseNode):
         )

         # handle invoke result
-        text, usage = self._handle_invoke_result(
+        generator = self._handle_invoke_result(
             invoke_result=invoke_result
         )

+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            yield event
+            if isinstance(event, ModelInvokeCompleted):
+                usage = event.usage
+
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

-        return text, usage
-
-    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
+    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
+            -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
         """
         Handle invoke result
         :param invoke_result: invoke result
         :return:
         """
+        if isinstance(invoke_result, LLMResult):
+            return
+
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ''
         usage = None
         for result in invoke_result:
             text = result.delta.message.content
             full_text += text

-            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
+            yield RunStreamChunkEvent(
+                chunk_content=text,
+                from_variable_selector=[self.node_id, 'text']
+            )

             if not model:
                 model = result.model

@@ -200,7 +238,10 @@ class LLMNode(BaseNode):
         if not usage:
             usage = LLMUsage.empty_usage()

-        return full_text, usage
+        yield ModelInvokeCompleted(
+            text=full_text,
+            usage=usage
+        )

     def _transform_chat_messages(self,
             messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate

@@ -213,13 +254,13 @@ class LLMNode(BaseNode):
         """

         if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
-            if messages.edition_type == 'jinja2':
+            if messages.edition_type == 'jinja2' and messages.jinja2_text:
                 messages.text = messages.jinja2_text

             return messages

         for message in messages:
-            if message.edition_type == 'jinja2':
+            if message.edition_type == 'jinja2' and message.jinja2_text:
                 message.text = message.jinja2_text

         return messages

@@ -319,7 +360,7 @@ class LLMNode(BaseNode):

             inputs[variable_selector.variable] = variable_value

-        return inputs
+        return inputs  # type: ignore

     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
         """

@@ -337,7 +378,7 @@ class LLMNode(BaseNode):

         return files

-    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
+    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
         """
         Fetch context
         :param node_data: node data

@@ -353,7 +394,10 @@ class LLMNode(BaseNode):
         context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
         if context_value:
             if isinstance(context_value, str):
-                return context_value
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=[],
+                    context=context_value
+                )
             elif isinstance(context_value, list):
                 context_str = ''
                 original_retriever_resource = []

@@ -370,17 +414,10 @@ class LLMNode(BaseNode):
                     if retriever_resource:
                         original_retriever_resource.append(retriever_resource)

-                if self.callbacks and original_retriever_resource:
-                    for callback in self.callbacks:
-                        callback.on_event(
-                            event=QueueRetrieverResourcesEvent(
-                                retriever_resources=original_retriever_resource
-                            )
-                        )
-
-                return context_str.strip()
-
-        return None
+                yield RunRetrieverResourceEvent(
+                    retriever_resources=original_retriever_resource,
+                    context=context_str.strip()
+                )

     def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
         """

@@ -561,7 +598,8 @@ class LLMNode(BaseNode):
             if not isinstance(prompt_message.content, str):
                 prompt_message_content = []
                 for content_item in prompt_message.content:
-                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
+                    if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
+                            content_item, ImagePromptMessageContent):
                         # Override vision config if LLM node has vision config
                         if vision_detail:
                             content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)

@@ -633,13 +671,13 @@ class LLMNode(BaseNode):
             db.session.commit()

     @classmethod
-    def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """
         Extract variable selector to variable mapping
         :param node_data: node data
         :return:
         """
+        node_data = cast(LLMNodeData, node_data)
         prompt_template = node_data.prompt_template

         variable_selectors = []

@@ -727,3 +765,11 @@ class LLMNode(BaseNode):
                 }
             }
         }
+
+
+class ModelInvokeCompleted(BaseModel):
+    """
+    Model invoke completed
+    """
+    text: str
+    usage: LLMUsage
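
_invoke_llm and _handle_invoke_result now form one generator pipeline: each streamed delta is surfaced as a RunStreamChunkEvent, and a final ModelInvokeCompleted carries the accumulated text and usage. A caller therefore forwards chunks and captures the terminal event, the pattern used both in _run above and in QuestionClassifierNode below; distilled (arguments elided):

    result_text = ''
    usage = LLMUsage.empty_usage()
    for event in generator:  # generator = self._invoke_llm(...)
        if isinstance(event, RunStreamChunkEvent):
            yield event  # forward streaming chunks to the caller
        elif isinstance(event, ModelInvokeCompleted):
            result_text, usage = event.text, event.usage
            break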

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.utils.encoders import jsonable_encoder

@@ -16,7 +17,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.llm.llm_node import LLMNode
+from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.question_classifier.template_prompts import (
     QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,

@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
     _node_data_cls = QuestionClassifierNodeData
     node_type = NodeType.QUESTION_CLASSIFIER

-    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
+    def _run(self) -> NodeRunResult:
         node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
         node_data = cast(QuestionClassifierNodeData, node_data)
+        variable_pool = self.graph_runtime_state.variable_pool

         # extract variables
         query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)

@@ -62,12 +64,21 @@ class QuestionClassifierNode(LLMNode):
         )

         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        generator = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
         )

+        result_text = ''
+        usage = LLMUsage.empty_usage()
+        for event in generator:
+            if isinstance(event, ModelInvokeCompleted):
+                result_text = event.text
+                usage = event.usage
+                break
+
         category_name = node_data.classes[0].name
         category_id = node_data.classes[0].id
         try:

View File

@@ -17,47 +17,17 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
 from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.graph_engine import GraphEngine
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
-from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.end.end_node import EndNode
-from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
-from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes
 from core.workflow.nodes.iteration.entities import IterationState
-from core.workflow.nodes.iteration.iteration_node import IterationNode
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
 from core.workflow.nodes.llm.entities import LLMNodeData
-from core.workflow.nodes.llm.llm_node import LLMNode
-from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
-from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
 from core.workflow.nodes.start.start_node import StartNode
-from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
-from core.workflow.nodes.tool.tool_node import ToolNode
-from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
 from extensions.ext_database import db
 from models.workflow import (
     Workflow,
     WorkflowNodeExecutionStatus,
+    WorkflowType,
 )

-node_classes = {
-    NodeType.START: StartNode,
-    NodeType.END: EndNode,
-    NodeType.ANSWER: AnswerNode,
-    NodeType.LLM: LLMNode,
-    NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
-    NodeType.IF_ELSE: IfElseNode,
-    NodeType.CODE: CodeNode,
-    NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
-    NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
-    NodeType.HTTP_REQUEST: HttpRequestNode,
-    NodeType.TOOL: ToolNode,
-    NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
-    NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
-    NodeType.ITERATION: IterationNode,
-    NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
-}
-
 logger = logging.getLogger(__name__)

@@ -115,6 +85,8 @@ class WorkflowEntry:
         graph_engine = GraphEngine(
             tenant_id=workflow.tenant_id,
             app_id=workflow.app_id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            workflow_id=workflow.id,
             user_id=user_id,
             user_from=user_from,
             invoke_from=invoke_from,