mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 02:35:56 +08:00
fix: unit tests in workflow
This commit is contained in:
parent
702df31db7
commit
c5192650fb
@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
|
@ -290,17 +290,16 @@ class Graph(BaseModel):
|
||||
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
|
||||
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
|
||||
|
||||
if not parent_parallel_id:
|
||||
raise Exception(f"Parent parallel id not found for node ids {parallel_node_ids}")
|
||||
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
raise Exception(f"Parent parallel {parent_parallel_id} not found")
|
||||
parent_parallel = None
|
||||
if parent_parallel_id:
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
raise Exception(f"Parent parallel {parent_parallel_id} not found")
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel.id,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id
|
||||
parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
|
@ -96,6 +96,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
@ -121,6 +122,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
if text:
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
|
@ -267,7 +267,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 21
|
||||
assert len(items) == 18
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == 'start'
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -62,6 +63,7 @@ def test_execute_answer():
|
||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||
|
||||
node = AnswerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
|
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@ -38,11 +39,13 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
parallel = graph.parallel_mapping.get(parallel_id)
|
||||
parallel_start_node_id = parallel.start_from_node_id if parallel else None
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
node_config = graph.node_id_config_mapping[next_node_id]
|
||||
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
|
||||
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
|
||||
|
||||
yield NodeRunStartedEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
@ -55,6 +58,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
length = int(next_node_id[-1])
|
||||
for i in range(0, length):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
@ -68,6 +72,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
route_node_state.status = RouteNodeState.Status.SUCCESS
|
||||
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_execution_id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=mock_node_data,
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -161,6 +162,7 @@ def test_run():
|
||||
pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -65,6 +66,7 @@ def test_execute_if_else_result_true():
|
||||
pool.add(['start', 'not_null'], '1212')
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
@ -254,6 +256,7 @@ def test_execute_if_else_result_false():
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
|
@ -1,17 +1,62 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = 'node_id'
|
||||
|
||||
|
||||
def test_overwrite_string_variable():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start"
|
||||
},
|
||||
"id": "start"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner"
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id='1',
|
||||
graph_config=graph_config,
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
conversation_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
@ -24,13 +69,27 @@ def test_overwrite_string_variable():
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
),
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
@ -41,19 +100,8 @@ def test_overwrite_string_variable():
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
variable_pool.add(
|
||||
[DEFAULT_NODE_ID, input_variable.name],
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
@ -63,6 +111,46 @@ def test_overwrite_string_variable():
|
||||
|
||||
|
||||
def test_append_variable_to_array():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start"
|
||||
},
|
||||
"id": "start"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner"
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id='1',
|
||||
graph_config=graph_config,
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
@ -75,23 +163,6 @@ def test_append_variable_to_array():
|
||||
value='the second value',
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.APPEND.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
@ -103,8 +174,26 @@ def test_append_variable_to_array():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
),
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.APPEND.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
node.run(variable_pool)
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
@ -113,19 +202,67 @@ def test_append_variable_to_array():
|
||||
|
||||
|
||||
def test_clear_array():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-assigner-target",
|
||||
"source": "start",
|
||||
"target": "assigner",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start"
|
||||
},
|
||||
"id": "start"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "assigner",
|
||||
},
|
||||
"id": "assigner"
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id='1',
|
||||
graph_config=graph_config,
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
),
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
@ -136,14 +273,9 @@ def test_clear_array():
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
)
|
||||
|
||||
node.run(variable_pool)
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
|
||||
list(node.run())
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
assert got is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user