mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 03:35:57 +08:00
graph engine implement
This commit is contained in:
parent
821e09b259
commit
00fb23d0c9
@ -3,6 +3,20 @@ from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -41,6 +55,25 @@ class NodeType(Enum):
|
||||
raise ValueError(f'invalid node type value {value}')
|
||||
|
||||
|
||||
node_classes = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
}
|
||||
|
||||
|
||||
class SystemVariable(Enum):
|
||||
"""
|
||||
System Variables.
|
||||
@ -90,3 +123,23 @@ class NodeRunResult(BaseModel):
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
"""
|
||||
User from
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "UserFrom":
|
||||
"""
|
||||
Value of
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
for item in cls:
|
||||
if item.value == value:
|
||||
return item
|
||||
raise ValueError(f"Invalid value: {value}")
|
||||
|
116
api/core/workflow/graph_engine/entities/event.py
Normal file
116
api/core/workflow/graph_engine/entities/event.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunBackToRootEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
reason: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
|
||||
node_id: str = Field(..., description="node id")
|
||||
parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
|
||||
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
pass
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
reason: str = Field("", description="failed reason")
|
||||
|
||||
@model_validator(mode='before')
|
||||
def init_reason(cls, values: dict) -> dict:
|
||||
if not values.get("reason"):
|
||||
values["reason"] = values.get("run_result").error or "Unknown error"
|
||||
return values
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelEvent(GraphEngineEvent):
|
||||
parallel_id: str = Field(..., description="parallel id")
|
||||
|
||||
|
||||
class ParallelRunStartedEvent(BaseParallelEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelRunSucceededEvent(BaseParallelEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelRunFailedEvent(BaseParallelEvent):
|
||||
reason: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
|
||||
iteration_id: str = Field(..., description="iteration id")
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
pass
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
pass
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
reason: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
|
@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str
|
||||
"""source node id"""
|
||||
|
||||
target_node_id: str
|
||||
"""target node id"""
|
||||
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""condition to run the edge"""
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: Optional[RunCondition] = Field(None, description="run condition")
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random uuid parallel id"""
|
||||
|
||||
start_from_node_id: str
|
||||
"""start from node id"""
|
||||
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if exists"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
|
||||
end_to_node_id: Optional[str] = Field(None, description="end to node id")
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str
|
||||
"""root node id of the graph"""
|
||||
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
"""graph node ids"""
|
||||
|
||||
node_id_config_mapping: dict[str, dict] = Field(default_factory=list)
|
||||
"""node configs mapping (node id: node config)"""
|
||||
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||
"""graph edge mapping (source node id: edges)"""
|
||||
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
|
||||
"""graph parallel mapping (parallel id: parallel)"""
|
||||
|
||||
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
|
||||
"""graph node parallel mapping (node id: parallel id)"""
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
node_id_config_mapping: dict[str, dict] = Field(
|
||||
default_factory=list,
|
||||
description="node configs mapping (node id: node config)"
|
||||
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict,
|
||||
description="graph parallel mapping (parallel id: parallel)"
|
||||
)
|
||||
node_parallel_mapping: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
|
17
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
17
api/core/workflow/graph_engine/entities/graph_init_params.py
Normal file
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
@ -1,25 +1,14 @@
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# init params
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
call_depth: int
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
|
||||
variable_pool: VariablePool
|
||||
start_at: float = Field(..., description="start time")
|
||||
total_tokens: int = Field(0, description="total tokens")
|
||||
node_run_steps: int = Field(0, description="node run steps")
|
||||
|
||||
start_at: float
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
|
||||
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.base_node import UserFrom, node_classes
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -25,6 +35,8 @@ logger = logging.getLogger(__name__)
|
||||
class GraphEngine:
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@ -33,13 +45,18 @@ class GraphEngine:
|
||||
variable_pool: VariablePool,
|
||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
@ -55,31 +72,31 @@ class GraphEngine:
|
||||
# TODO convert generator to result
|
||||
pass
|
||||
|
||||
def run(self) -> Generator:
|
||||
# TODO trigger graph run start event
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
# TODO run graph
|
||||
rst = self._run(start_node_id=self.graph.root_node_id)
|
||||
except GraphRunFailedError as e:
|
||||
# TODO self._graph_run_failed(
|
||||
# error=e.error,
|
||||
# callbacks=callbacks
|
||||
# )
|
||||
pass
|
||||
# run graph
|
||||
generator = self._run(start_node_id=self.graph.root_node_id)
|
||||
for item in generator:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(reason=item.reason)
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
yield GraphRunSucceededEvent()
|
||||
except (GraphRunFailedError, NodeRunFailedError) as e:
|
||||
yield GraphRunFailedEvent(reason=e.error)
|
||||
return
|
||||
except Exception as e:
|
||||
# TODO self._workflow_run_failed(
|
||||
# error=str(e),
|
||||
# callbacks=callbacks
|
||||
# )
|
||||
pass
|
||||
yield GraphRunFailedEvent(reason=str(e))
|
||||
return
|
||||
|
||||
# TODO trigger graph run success event
|
||||
|
||||
yield rst
|
||||
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||
next_node_id = start_node_id
|
||||
previous_node_id = None
|
||||
while True:
|
||||
# max steps reached
|
||||
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
|
||||
@ -92,10 +109,18 @@ class GraphEngine:
|
||||
):
|
||||
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
||||
|
||||
# run node TODO generator
|
||||
yield from self._run_node(node_id=next_node_id)
|
||||
try:
|
||||
# run node
|
||||
yield from self._run_node(
|
||||
node_id=next_node_id,
|
||||
previous_node_id=previous_node_id,
|
||||
parallel_id=in_parallel_id
|
||||
)
|
||||
except Exception as e:
|
||||
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
||||
return
|
||||
|
||||
# todo if failed, break
|
||||
previous_node_id = next_node_id
|
||||
|
||||
# get next node ids
|
||||
edge_mappings = self.graph.edge_mapping.get(next_node_id)
|
||||
@ -135,11 +160,11 @@ class GraphEngine:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
|
||||
if not parallel_id:
|
||||
raise GraphRunFailedError('Node related parallel not found.')
|
||||
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
raise GraphRunFailedError('Parallel not found.')
|
||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
||||
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
@ -149,8 +174,9 @@ class GraphEngine:
|
||||
for edge in edge_mappings:
|
||||
futures.append(thread_pool.submit(
|
||||
self._run_parallel_node,
|
||||
flask_app=current_app._get_current_object(),
|
||||
parallel_start_node_id=edge.source_node_id,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
|
||||
q=q
|
||||
))
|
||||
|
||||
@ -165,8 +191,9 @@ class GraphEngine:
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
# not necessary
|
||||
# for future in as_completed(futures):
|
||||
# future.result()
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
@ -178,48 +205,61 @@ class GraphEngine:
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
|
||||
def _run_parallel_node(self,
|
||||
flask_app: Flask,
|
||||
parallel_id: str,
|
||||
parallel_start_node_id: str,
|
||||
q: queue.Queue) -> None:
|
||||
"""
|
||||
Run parallel nodes
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
|
||||
if not in_parallel_id:
|
||||
q.put(None)
|
||||
return
|
||||
|
||||
# run node TODO generator
|
||||
rst = self._run(
|
||||
# run node
|
||||
generator = self._run(
|
||||
start_node_id=parallel_start_node_id,
|
||||
in_parallel_id=in_parallel_id
|
||||
in_parallel_id=parallel_id
|
||||
)
|
||||
|
||||
if not rst:
|
||||
q.put(None)
|
||||
return
|
||||
|
||||
for item in rst:
|
||||
if generator:
|
||||
for item in generator:
|
||||
q.put(item)
|
||||
|
||||
q.put(None)
|
||||
except Exception:
|
||||
logger.exception("Unknown Error when generating in parallel")
|
||||
finally:
|
||||
q.put(None)
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(self, node_id: str) -> Generator:
|
||||
def _run_node(self,
|
||||
node_id: str,
|
||||
previous_node_id: Optional[str] = None,
|
||||
parallel_id: Optional[str] = None
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
# get node config
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError('Node not found.')
|
||||
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
||||
|
||||
# todo convert to specific node
|
||||
# convert to specific node
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||
|
||||
# todo trigger node run start event
|
||||
# init workflow run state
|
||||
node_instance = node_cls( # type: ignore
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_node_id=previous_node_id
|
||||
)
|
||||
|
||||
# trigger node run start event
|
||||
yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
|
||||
db.session.close()
|
||||
|
||||
@ -229,27 +269,24 @@ class GraphEngine:
|
||||
|
||||
try:
|
||||
# run node
|
||||
rst = node.run(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
graph=self.graph,
|
||||
callbacks=self.callbacks
|
||||
)
|
||||
generator = node_instance.run()
|
||||
|
||||
yield from rst
|
||||
yield from generator
|
||||
|
||||
# todo record state
|
||||
|
||||
# trigger node run success event
|
||||
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
|
||||
except GenerateTaskStoppedException as e:
|
||||
# TODO yield failed
|
||||
# todo trigger node run failed event
|
||||
pass
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
# logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
# TODO yield failed
|
||||
# todo trigger node run failed event
|
||||
pass
|
||||
|
||||
# todo trigger node run success event
|
||||
|
||||
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
# trigger node run failed event
|
||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||
return
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
@ -265,3 +302,8 @@ class GraphEngine:
|
||||
class GraphRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
|
||||
|
||||
class NodeRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
|
@ -1,33 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
"""
|
||||
User from
|
||||
"""
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end-user"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "UserFrom":
|
||||
"""
|
||||
Value of
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
for item in cls:
|
||||
if item.value == value:
|
||||
return item
|
||||
raise ValueError(f"Invalid value: {value}")
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
@ -36,64 +20,66 @@ class BaseNode(ABC):
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_type: WorkflowType
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
workflow_call_depth: int
|
||||
graph: Graph
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
previous_node_id: Optional[str] = None
|
||||
|
||||
node_id: str
|
||||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
def __init__(self,
|
||||
config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
workflow_call_depth: int = 0) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.workflow_id = workflow_id
|
||||
self.user_id = user_id
|
||||
self.user_from = user_from
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
graph_init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_node_id: Optional[str] = None) -> None:
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
self.workflow_type = graph_init_params.workflow_type
|
||||
self.workflow_id = graph_init_params.workflow_id
|
||||
self.user_id = graph_init_params.user_id
|
||||
self.user_from = graph_init_params.user_from
|
||||
self.invoke_from = graph_init_params.invoke_from
|
||||
self.workflow_call_depth = graph_init_params.call_depth
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.previous_node_id = previous_node_id
|
||||
|
||||
self.node_id = config.get("id")
|
||||
if not self.node_id:
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self.node_data = self._node_data_cls(**config.get("data", {}))
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) \
|
||||
-> NodeRunResult | Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def run(self) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Run node entry
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
result = self._run()
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
if isinstance(result, NodeRunResult):
|
||||
yield RunCompletedEvent(
|
||||
run_result=result
|
||||
)
|
||||
else:
|
||||
yield from result
|
||||
|
||||
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
|
||||
"""
|
||||
@ -102,13 +88,13 @@ class BaseNode(ABC):
|
||||
:param value_selector: value selector
|
||||
:return:
|
||||
"""
|
||||
# TODO remove callbacks
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_node_text_chunk(
|
||||
node_id=self.node_id,
|
||||
text=text,
|
||||
metadata={
|
||||
"node_type": self.node_type,
|
||||
"value_selector": value_selector
|
||||
}
|
||||
)
|
||||
|
20
api/core/workflow/nodes/event.py
Normal file
20
api/core/workflow/nodes/event.py
Normal file
@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class RunStreamChunkEvent(BaseModel):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
|
@ -3,15 +3,16 @@ from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
@ -23,9 +24,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
@ -43,13 +47,13 @@ class LLMNode(BaseNode):
|
||||
_node_data_cls = LLMNodeData
|
||||
node_type = NodeType.LLM
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(LLMNodeData, deepcopy(self.node_data))
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
node_inputs = None
|
||||
process_data = None
|
||||
@ -76,10 +80,17 @@ class LLMNode(BaseNode):
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
context = self._fetch_context(node_data, variable_pool)
|
||||
generator = self._fetch_context(node_data, variable_pool)
|
||||
context = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
context = event.context
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
retriever_resources=event.retriever_resources
|
||||
)
|
||||
|
||||
if context:
|
||||
node_inputs['#context#'] = context
|
||||
node_inputs['#context#'] = context # type: ignore
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
@ -90,7 +101,7 @@ class LLMNode(BaseNode):
|
||||
# fetch prompt messages
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
node_data=node_data,
|
||||
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
|
||||
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) # type: ignore
|
||||
if node_data.memory else None,
|
||||
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
|
||||
inputs=inputs,
|
||||
@ -109,26 +120,40 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._invoke_llm(
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
)
|
||||
|
||||
result_text = ''
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
if isinstance(event, RunStreamChunkEvent):
|
||||
yield event
|
||||
elif isinstance(event, ModelInvokeCompleted):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
break
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {
|
||||
'text': result_text,
|
||||
'usage': jsonable_encoder(usage)
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
@ -139,11 +164,13 @@ class LLMNode(BaseNode):
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: list[str]) -> tuple[str, LLMUsage]:
|
||||
stop: Optional[list[str]] = None) \
|
||||
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
@ -163,30 +190,41 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
text, usage = self._handle_invoke_result(
|
||||
generator = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
yield event
|
||||
if isinstance(event, ModelInvokeCompleted):
|
||||
usage = event.usage
|
||||
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:return:
|
||||
"""
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
return
|
||||
|
||||
model = None
|
||||
prompt_messages = []
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ''
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
full_text += text
|
||||
|
||||
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@ -200,7 +238,10 @@ class LLMNode(BaseNode):
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
return full_text, usage
|
||||
yield ModelInvokeCompleted(
|
||||
text=full_text,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def _transform_chat_messages(self,
|
||||
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -213,13 +254,13 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
|
||||
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
|
||||
if messages.edition_type == 'jinja2':
|
||||
if messages.edition_type == 'jinja2' and messages.jinja2_text:
|
||||
messages.text = messages.jinja2_text
|
||||
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if message.edition_type == 'jinja2':
|
||||
if message.edition_type == 'jinja2' and message.jinja2_text:
|
||||
message.text = message.jinja2_text
|
||||
|
||||
return messages
|
||||
@ -319,7 +360,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
inputs[variable_selector.variable] = variable_value
|
||||
|
||||
return inputs
|
||||
return inputs # type: ignore
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||
"""
|
||||
@ -337,7 +378,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
|
||||
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Fetch context
|
||||
:param node_data: node data
|
||||
@ -353,7 +394,10 @@ class LLMNode(BaseNode):
|
||||
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
return context_value
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=[],
|
||||
context=context_value
|
||||
)
|
||||
elif isinstance(context_value, list):
|
||||
context_str = ''
|
||||
original_retriever_resource = []
|
||||
@ -370,17 +414,10 @@ class LLMNode(BaseNode):
|
||||
if retriever_resource:
|
||||
original_retriever_resource.append(retriever_resource)
|
||||
|
||||
if self.callbacks and original_retriever_resource:
|
||||
for callback in self.callbacks:
|
||||
callback.on_event(
|
||||
event=QueueRetrieverResourcesEvent(
|
||||
retriever_resources=original_retriever_resource
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=original_retriever_resource,
|
||||
context=context_str.strip()
|
||||
)
|
||||
)
|
||||
|
||||
return context_str.strip()
|
||||
|
||||
return None
|
||||
|
||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
||||
"""
|
||||
@ -561,7 +598,8 @@ class LLMNode(BaseNode):
|
||||
if not isinstance(prompt_message.content, str):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content:
|
||||
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
|
||||
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
|
||||
content_item, ImagePromptMessageContent):
|
||||
# Override vision config if LLM node has vision config
|
||||
if vision_detail:
|
||||
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
|
||||
@ -633,13 +671,13 @@ class LLMNode(BaseNode):
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
variable_selectors = []
|
||||
@ -727,3 +765,11 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ModelInvokeCompleted(BaseModel):
|
||||
"""
|
||||
Model invoke completed
|
||||
"""
|
||||
text: str
|
||||
usage: LLMUsage
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional, Union, cast
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -16,7 +17,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.question_classifier.template_prompts import (
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
|
||||
@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
|
||||
_node_data_cls = QuestionClassifierNodeData
|
||||
node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
|
||||
node_data = cast(QuestionClassifierNodeData, node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
|
||||
@ -62,12 +64,21 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
result_text, usage = self._invoke_llm(
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop
|
||||
)
|
||||
|
||||
result_text = ''
|
||||
usage = LLMUsage.empty_usage()
|
||||
for event in generator:
|
||||
if isinstance(event, ModelInvokeCompleted):
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
break
|
||||
|
||||
category_name = node_data.classes[0].name
|
||||
category_id = node_data.classes[0].id
|
||||
try:
|
||||
|
@ -17,47 +17,17 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes
|
||||
from core.workflow.nodes.iteration.entities import IterationState
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -115,6 +85,8 @@ class WorkflowEntry:
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
workflow_id=workflow.id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
|
Loading…
x
Reference in New Issue
Block a user