graph engine implementation

takatost 2024-07-15 23:40:02 +08:00
parent 821e09b259
commit 00fb23d0c9
11 changed files with 511 additions and 270 deletions

View File

@@ -3,6 +3,20 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from models.workflow import WorkflowNodeExecutionStatus
@@ -41,6 +55,25 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
class SystemVariable(Enum):
"""
System Variables.
@@ -90,3 +123,23 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")

View File

@@ -0,0 +1,116 @@
from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.workflow.entities.node_entities import NodeRunResult
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunBackToRootEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
pass
class GraphRunFailedEvent(BaseGraphEvent):
reason: str = Field(..., description="failed reason")
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
node_id: str = Field(..., description="node id")
parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class NodeRunStartedEvent(BaseNodeEvent):
pass
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
class NodeRunSucceededEvent(BaseNodeEvent):
run_result: NodeRunResult = Field(..., description="run result")
class NodeRunFailedEvent(BaseNodeEvent):
run_result: NodeRunResult = Field(..., description="run result")
reason: str = Field("", description="failed reason")
@model_validator(mode='before')
def init_reason(cls, values: dict) -> dict:
if not values.get("reason"):
values["reason"] = values.get("run_result").error or "Unknown error"
return values
###########################################
# Parallel Events
###########################################
class BaseParallelEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
class ParallelRunStartedEvent(BaseParallelEvent):
pass
class ParallelRunSucceededEvent(BaseParallelEvent):
pass
class ParallelRunFailedEvent(BaseParallelEvent):
reason: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration id")
class IterationRunStartedEvent(BaseIterationEvent):
pass
class IterationRunSucceededEvent(BaseIterationEvent):
pass
class IterationRunFailedEvent(BaseIterationEvent):
reason: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
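
Because every event class derives from GraphEngineEvent, a consumer of the engine's generator can dispatch on concrete types with isinstance. A self-contained sketch of that consumption pattern (fake_run stands in for GraphEngine.run, and only two of the event classes are re-declared locally):

from pydantic import BaseModel, Field

class GraphEngineEvent(BaseModel):
    pass

class GraphRunStartedEvent(GraphEngineEvent):
    pass

class GraphRunFailedEvent(GraphEngineEvent):
    reason: str = Field(..., description="failed reason")

def fake_run():
    # stand-in for GraphEngine.run(), which yields events in order
    yield GraphRunStartedEvent()
    yield GraphRunFailedEvent(reason="max steps reached")

for event in fake_run():
    if isinstance(event, GraphRunFailedEvent):
        print(f"graph run failed: {event.reason}")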

View File

@@ -8,48 +8,37 @@ from core.workflow.graph_engine.entities.run_condition import RunCondition
class GraphEdge(BaseModel):
source_node_id: str
"""source node id"""
target_node_id: str
"""target node id"""
run_condition: Optional[RunCondition] = None
"""condition to run the edge"""
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = Field(None, description="run condition")
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random uuid parallel id"""
start_from_node_id: str
"""start from node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if exists"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
end_to_node_id: Optional[str] = Field(None, description="end to node id")
class Graph(BaseModel):
root_node_id: str
"""root node id of the graph"""
node_ids: list[str] = Field(default_factory=list)
"""graph node ids"""
node_id_config_mapping: dict[str, dict] = Field(default_factory=list)
"""node configs mapping (node id: node config)"""
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
"""graph edge mapping (source node id: edges)"""
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
"""graph parallel mapping (parallel id: parallel)"""
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
"""graph node parallel mapping (node id: parallel id)"""
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=dict,
description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
@classmethod
def init(cls,
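
The edge_mapping field above is keyed by source node id, with one GraphEdge per outgoing edge. A sketch of building that shape, using a standalone re-declaration rather than the repo's class (which also carries run_condition):

from pydantic import BaseModel, Field

class GraphEdge(BaseModel):
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")

edge_mapping: dict[str, list[GraphEdge]] = {}
for source, target in [("start", "llm"), ("llm", "end")]:
    edge_mapping.setdefault(source, []).append(
        GraphEdge(source_node_id=source, target_node_id=target)
    )
assert [e.target_node_id for e in edge_mapping["start"]] == ["llm"]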

View File

@@ -0,0 +1,17 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")

View File

@@ -1,25 +1,14 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
# init params
tenant_id: str
app_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
call_depth: int
variable_pool: VariablePool = Field(..., description="variable pool")
variable_pool: VariablePool
start_at: float = Field(..., description="start time")
total_tokens: int = Field(0, description="total tokens")
node_run_steps: int = Field(0, description="node run steps")
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
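
start_at is recorded with time.perf_counter() when the engine is constructed, which pairs with the _is_timed_out check later in this commit. A self-contained sketch of that elapsed-time test (perf_counter is monotonic, so wall-clock adjustments cannot skew it):

import time

def is_timed_out(start_at: float, max_execution_time: int) -> bool:
    # compare a perf_counter delta against the configured budget in seconds
    return time.perf_counter() - start_at > max_execution_time

start_at = time.perf_counter()
assert not is_timed_out(start_at, max_execution_time=600)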

View File

@@ -2,7 +2,7 @@ import logging
import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, cast
from flask import Flask, current_app
@@ -13,10 +13,20 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.base_node import UserFrom, node_classes
from extensions.ext_database import db
from models.workflow import WorkflowType
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__)
@@ -25,6 +35,8 @@ logger = logging.getLogger(__name__)
class GraphEngine:
def __init__(self, tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -33,13 +45,18 @@
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph = graph
self.graph_runtime_state = GraphRuntimeState(
self.init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
call_depth=call_depth
)
self.graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter()
)
@@ -55,31 +72,31 @@
# TODO convert generator to result
pass
def run(self) -> Generator:
# TODO trigger graph run start event
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
try:
# TODO run graph
rst = self._run(start_node_id=self.graph.root_node_id)
except GraphRunFailedError as e:
# TODO self._graph_run_failed(
# error=e.error,
# callbacks=callbacks
# )
pass
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
for item in generator:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.reason)
return
# trigger graph run success event
yield GraphRunSucceededEvent()
except (GraphRunFailedError, NodeRunFailedError) as e:
yield GraphRunFailedEvent(reason=e.error)
return
except Exception as e:
# TODO self._workflow_run_failed(
# error=str(e),
# callbacks=callbacks
# )
pass
yield GraphRunFailedEvent(reason=str(e))
return
# TODO trigger graph run success event
yield rst
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
next_node_id = start_node_id
previous_node_id = None
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
@@ -92,10 +109,18 @@
):
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# run node TODO generator
yield from self._run_node(node_id=next_node_id)
try:
# run node
yield from self._run_node(
node_id=next_node_id,
previous_node_id=previous_node_id,
parallel_id=in_parallel_id
)
except Exception as e:
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
return
# todo if failed, break
previous_node_id = next_node_id
# get next node ids
edge_mappings = self.graph.edge_mapping.get(next_node_id)
@@ -135,11 +160,11 @@
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
if not parallel_id:
raise GraphRunFailedError('Node related parallel not found.')
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError('Parallel not found.')
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
# run parallel nodes, run in new thread and use queue to get results
q: queue.Queue = queue.Queue()
@@ -149,8 +174,9 @@
for edge in edge_mappings:
futures.append(thread_pool.submit(
self._run_parallel_node,
flask_app=current_app._get_current_object(),
parallel_start_node_id=edge.source_node_id,
flask_app=current_app._get_current_object(), # type: ignore
parallel_id=parallel_id,
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
q=q
))
@@ -165,8 +191,9 @@
except queue.Empty:
continue
for future in as_completed(futures):
future.result()
# not necessary
# for future in as_completed(futures):
# future.result()
# get final node id
final_node_id = parallel.end_to_node_id
@@ -178,48 +205,61 @@
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
break
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
def _run_parallel_node(self,
flask_app: Flask,
parallel_id: str,
parallel_start_node_id: str,
q: queue.Queue) -> None:
"""
Run parallel nodes
"""
with flask_app.app_context():
try:
in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
if not in_parallel_id:
q.put(None)
return
# run node TODO generator
rst = self._run(
# run node
generator = self._run(
start_node_id=parallel_start_node_id,
in_parallel_id=in_parallel_id
in_parallel_id=parallel_id
)
if not rst:
q.put(None)
return
for item in rst:
if generator:
for item in generator:
q.put(item)
q.put(None)
except Exception:
logger.exception("Unknown Error when generating in parallel")
finally:
q.put(None)
db.session.remove()
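
Each parallel branch runs in a worker thread and forwards its events into a shared queue.Queue, putting a None sentinel when it finishes (the finally above guarantees the sentinel even on error). A condensed, self-contained model of that fan-in, with two fake branches instead of graph nodes; the real engine also applies a timeout on q.get:

import queue
from concurrent.futures import ThreadPoolExecutor

def run_branch(branch_id: str, q: queue.Queue) -> None:
    try:
        for step in range(2):
            q.put(f"{branch_id}: step {step}")
    finally:
        q.put(None)  # sentinel: this branch is done, even if it raised

q: queue.Queue = queue.Queue()
branches = ["a", "b"]
with ThreadPoolExecutor(max_workers=2) as pool:
    for branch_id in branches:
        pool.submit(run_branch, branch_id, q)
    finished = 0
    while finished < len(branches):
        item = q.get()
        if item is None:
            finished += 1
        else:
            print(item)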
def _run_node(self, node_id: str) -> Generator:
def _run_node(self,
node_id: str,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
"""
# get node config
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError('Node not found.')
raise GraphRunFailedError(f'Node {node_id} config not found.')
# todo convert to specific node
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
# todo trigger node run start event
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
# trigger node run start event
yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id)
db.session.close()
@@ -229,27 +269,24 @@
try:
# run node
rst = node.run(
graph_runtime_state=self.graph_runtime_state,
graph=self.graph,
callbacks=self.callbacks
)
generator = node_instance.run()
yield from rst
yield from generator
# todo record state
# trigger node run success event
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
except GenerateTaskStoppedException as e:
# TODO yield failed
# todo trigger node run failed event
pass
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
except Exception as e:
# logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# TODO yield failed
# todo trigger node run failed event
pass
# todo trigger node run success event
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
finally:
db.session.close()
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
@@ -265,3 +302,8 @@
class GraphRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
class NodeRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
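
The run/_run split above follows a common generator pattern: an outer generator brackets the inner one with start/succeeded events and converts exceptions into a terminal failure event. A condensed, self-contained model of that control flow, with plain strings standing in for the event classes:

from collections.abc import Generator

class GraphRunFailedError(Exception):
    def __init__(self, error: str):
        self.error = error

def _run() -> Generator[str, None, None]:
    yield "node: started"
    raise GraphRunFailedError("Max execution time 600s reached.")

def run() -> Generator[str, None, None]:
    yield "graph: started"
    try:
        yield from _run()
    except GraphRunFailedError as e:
        yield f"graph: failed ({e.error})"
        return
    yield "graph: succeeded"

print(list(run()))  # graph started, node started, graph failed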

View File

@@ -1,33 +1,17 @@
from abc import ABC, abstractmethod
from enum import Enum
from collections.abc import Generator
from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iterable_node import IterableNodeMixin
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")
from models.workflow import WorkflowType
class BaseNode(ABC):
@@ -36,64 +20,66 @@ class BaseNode(ABC):
tenant_id: str
app_id: str
workflow_type: WorkflowType
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
graph: Graph
graph_runtime_state: GraphRuntimeState
previous_node_id: Optional[str] = None
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: list[BaseWorkflowCallback]
def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
def __init__(self,
config: dict,
callbacks: list[BaseWorkflowCallback] = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None) -> None:
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.node_id = config.get("id")
if not self.node_id:
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) \
-> NodeRunResult | Generator[RunEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> NodeRunResult:
def run(self) -> Generator[RunEvent, None, None]:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
result = self._run(
variable_pool=variable_pool
)
result = self._run()
self.node_run_result = result
return result
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(
run_result=result
)
else:
yield from result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
"""
@@ -102,13 +88,13 @@ class BaseNode(ABC):
:param value_selector: value selector
:return:
"""
# TODO remove callbacks
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text,
metadata={
"node_type": self.node_type,
"value_selector": value_selector
}
)
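
After this change _run may return either a finished NodeRunResult or a generator of RunEvents, and run() normalizes both into a single event stream. A self-contained sketch of that normalization, using local re-declarations rather than the repo classes:

from collections.abc import Generator
from pydantic import BaseModel

class NodeRunResult(BaseModel):
    outputs: dict = {}

class RunCompletedEvent(BaseModel):
    run_result: NodeRunResult

def run(result) -> Generator:
    # plain results are wrapped in a completion event; generators pass through
    if isinstance(result, NodeRunResult):
        yield RunCompletedEvent(run_result=result)
    else:
        yield from result

events = list(run(NodeRunResult(outputs={"text": "hi"})))
assert isinstance(events[0], RunCompletedEvent)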

View File

@@ -0,0 +1,20 @@
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
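
RunEvent is a PEP 604 union alias; on Python 3.10+ such a types.UnionType value also works as the second argument to isinstance, which lets downstream code accept "any run event" directly. A tiny illustration with generic stand-in classes:

class A:
    pass

class B:
    pass

AOrB = A | B  # types.UnionType on Python 3.10+

assert isinstance(A(), AOrB)
assert not isinstance("text", AOrB)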

View File

@@ -3,15 +3,16 @@ from collections.abc import Generator
from copy import deepcopy
from typing import Optional, cast
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -23,9 +24,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
@@ -43,13 +47,13 @@ class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> Generator[RunEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = cast(LLMNodeData, deepcopy(self.node_data))
variable_pool = self.graph_runtime_state.variable_pool
node_inputs = None
process_data = None
@@ -76,10 +80,17 @@
node_inputs['#files#'] = [file.to_dict() for file in files]
# fetch context value
context = self._fetch_context(node_data, variable_pool)
generator = self._fetch_context(node_data, variable_pool)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield NodeRunRetrieverResourceEvent(
retriever_resources=event.retriever_resources
)
if context:
node_inputs['#context#'] = context
node_inputs['#context#'] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
@@ -90,7 +101,7 @@
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) # type: ignore
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@@ -109,26 +120,40 @@
}
# handle invoke result
result_text, usage = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
for event in generator:
if isinstance(event, RunStreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
break
except Exception as e:
return NodeRunResult(
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
)
)
return
outputs = {
'text': result_text,
'usage': jsonable_encoder(usage)
}
return NodeRunResult(
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
@@ -139,11 +164,13 @@
NodeRunMetadataKey.CURRENCY: usage.currency
}
)
)
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str]) -> tuple[str, LLMUsage]:
stop: Optional[list[str]] = None) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@@ -163,30 +190,41 @@
)
# handle invoke result
text, usage = self._handle_invoke_result(
generator = self._handle_invoke_result(
invoke_result=invoke_result
)
usage = LLMUsage.empty_usage()
for event in generator:
yield event
if isinstance(event, ModelInvokeCompleted):
usage = event.usage
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
if isinstance(invoke_result, LLMResult):
return
model = None
prompt_messages = []
prompt_messages: list[PromptMessage] = []
full_text = ''
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
yield RunStreamChunkEvent(
chunk_content=text,
from_variable_selector=[self.node_id, 'text']
)
if not model:
model = result.model
@@ -200,7 +238,10 @@
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage
yield ModelInvokeCompleted(
text=full_text,
usage=usage
)
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@@ -213,13 +254,13 @@
"""
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == 'jinja2':
if messages.edition_type == 'jinja2' and messages.jinja2_text:
messages.text = messages.jinja2_text
return messages
for message in messages:
if message.edition_type == 'jinja2':
if message.edition_type == 'jinja2' and message.jinja2_text:
message.text = message.jinja2_text
return messages
@@ -319,7 +360,7 @@
inputs[variable_selector.variable] = variable_value
return inputs
return inputs # type: ignore
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
"""
@@ -337,7 +378,7 @@
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
@@ -353,7 +394,10 @@
context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
return context_value
yield RunRetrieverResourceEvent(
retriever_resources=[],
context=context_value
)
elif isinstance(context_value, list):
context_str = ''
original_retriever_resource = []
@@ -370,17 +414,10 @@
if retriever_resource:
original_retriever_resource.append(retriever_resource)
if self.callbacks and original_retriever_resource:
for callback in self.callbacks:
callback.on_event(
event=QueueRetrieverResourcesEvent(
retriever_resources=original_retriever_resource
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip()
)
)
return context_str.strip()
return None
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
@@ -561,7 +598,8 @@
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@@ -633,13 +671,13 @@
db.session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
node_data = cast(LLMNodeData, node_data)
prompt_template = node_data.prompt_template
variable_selectors = []
@@ -727,3 +765,11 @@
}
}
}
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
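
ModelInvokeCompleted acts as a terminal sentinel in the stream produced by _invoke_llm/_handle_invoke_result: callers forward chunk events and stop once the completed marker arrives. A self-contained model of that consumption, with local stand-in classes (the real ModelInvokeCompleted also carries usage):

from collections.abc import Generator
from pydantic import BaseModel

class RunStreamChunkEvent(BaseModel):
    chunk_content: str

class ModelInvokeCompleted(BaseModel):
    text: str

def invoke() -> Generator[RunStreamChunkEvent | ModelInvokeCompleted, None, None]:
    parts = ["Hel", "lo"]
    for part in parts:
        yield RunStreamChunkEvent(chunk_content=part)
    yield ModelInvokeCompleted(text="".join(parts))

result_text = ""
for event in invoke():
    if isinstance(event, RunStreamChunkEvent):
        print(event.chunk_content, end="")
    elif isinstance(event, ModelInvokeCompleted):
        result_text = event.text
        break
assert result_text == "Hello"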

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -16,7 +17,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector)
@@ -62,12 +64,21 @@
)
# handle invoke result
result_text, usage = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
for event in generator:
if isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:

View File

@@ -17,47 +17,17 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes
from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
WorkflowNodeExecutionStatus,
WorkflowType,
)
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
logger = logging.getLogger(__name__)
@@ -115,6 +85,8 @@ class WorkflowEntry:
graph_engine = GraphEngine(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type),
workflow_id=workflow.id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,