fix: unit tests in workflow

takatost 2024-08-15 23:47:59 +08:00
parent 702df31db7
commit c5192650fb
9 changed files with 206 additions and 62 deletions

View File

@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,

View File

@@ -290,17 +290,16 @@ class Graph(BaseModel):
         if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
             parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
-            if not parent_parallel_id:
-                raise Exception(f"Parent parallel id not found for node ids {parallel_node_ids}")
-            parent_parallel = parallel_mapping.get(parent_parallel_id)
-            if not parent_parallel:
-                raise Exception(f"Parent parallel {parent_parallel_id} not found")
+            parent_parallel = None
+            if parent_parallel_id:
+                parent_parallel = parallel_mapping.get(parent_parallel_id)
+                if not parent_parallel:
+                    raise Exception(f"Parent parallel {parent_parallel_id} not found")
             parallel = GraphParallel(
                 start_from_node_id=start_node_id,
-                parent_parallel_id=parent_parallel.id,
-                parent_parallel_start_node_id=parent_parallel.start_from_node_id
+                parent_parallel_id=parent_parallel.id if parent_parallel else None,
+                parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
             )
             parallel_mapping[parallel.id] = parallel
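For reference (not part of the diff): a minimal standalone sketch of the guarded lookup above, using an illustrative stand-in for GraphParallel rather than the real Pydantic model.

import uuid
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _Parallel:
    # Illustrative stand-in for GraphParallel; the real model has more fields.
    start_from_node_id: str
    parent_parallel_id: Optional[str] = None
    parent_parallel_start_node_id: Optional[str] = None
    id: str = field(default_factory=lambda: str(uuid.uuid4()))


def build_parallel(start_node_id: str,
                   parent_parallel_id: Optional[str],
                   parallel_mapping: dict[str, _Parallel]) -> _Parallel:
    # A missing parent id is tolerated and simply yields a top-level parallel;
    # only a dangling reference to an unknown parent still raises.
    parent = None
    if parent_parallel_id:
        parent = parallel_mapping.get(parent_parallel_id)
        if not parent:
            raise Exception(f"Parent parallel {parent_parallel_id} not found")
    return _Parallel(
        start_from_node_id=start_node_id,
        parent_parallel_id=parent.id if parent else None,
        parent_parallel_start_node_id=parent.start_from_node_id if parent else None,
    )


# The old code raised whenever parent_parallel_id was empty; now this works:
assert build_parallel('node-a', None, {}).parent_parallel_id is None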

View File

@@ -96,6 +96,7 @@ class AnswerStreamProcessor(StreamProcessor):
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
@@ -121,6 +122,7 @@ class AnswerStreamProcessor(StreamProcessor):
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,

View File

@@ -267,7 +267,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
]:
assert item.parallel_id is not None
-    assert len(items) == 21
+    assert len(items) == 18
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'

View File

@@ -1,4 +1,5 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -62,6 +63,7 @@ def test_execute_answer():
pool.add(['llm', 'text'], 'You are a helpful AI.')
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
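The same construction pattern recurs in the node tests changed below; a hedged sketch of a helper one could factor out. make_node is hypothetical; the keyword arguments simply mirror the diff.

import time
import uuid

from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState


def make_node(node_cls, init_params, graph, variable_pool, config):
    # Hypothetical test helper: every node now takes an explicit execution id
    # plus the shared graph context and a runtime state that owns the pool.
    return node_cls(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=time.perf_counter(),
        ),
        config=config,
    )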

View File

@@ -1,3 +1,4 @@
import uuid
from collections.abc import Generator
from datetime import datetime, timezone
@@ -38,11 +39,13 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
node_execution_id = str(uuid.uuid4())
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
yield NodeRunStartedEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
@@ -55,6 +58,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
length = int(next_node_id[-1])
for i in range(0, length):
yield NodeRunStreamChunkEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
@@ -68,6 +72,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
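The started, stream-chunk and succeeded events above now carry the same node_execution_id; a small self-contained sketch, with stand-in objects rather than the real NodeRun*Event classes, of why carrying that id on every event is useful.

import uuid
from dataclasses import dataclass


@dataclass
class _Event:
    # Stand-in for the NodeRun*Event classes: id is the node execution id.
    id: str
    node_id: str
    kind: str


def chunks_by_execution(events: list[_Event]) -> dict[str, list[_Event]]:
    # Stream chunks can be grouped by execution id, so two runs of the same
    # node_id (for example inside an iteration) never get their chunks mixed up.
    grouped: dict[str, list[_Event]] = {}
    for event in events:
        if event.kind == 'stream_chunk':
            grouped.setdefault(event.id, []).append(event)
    return grouped


run_id = str(uuid.uuid4())
events = [
    _Event(id=run_id, node_id='answer', kind='started'),
    _Event(id=run_id, node_id='answer', kind='stream_chunk'),
    _Event(id=run_id, node_id='answer', kind='succeeded'),
]
assert list(chunks_by_execution(events)) == [run_id]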

View File

@@ -1,4 +1,5 @@
import time
import uuid
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -161,6 +162,7 @@ def test_run():
pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(

View File

@@ -1,4 +1,5 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -65,6 +66,7 @@ def test_execute_if_else_result_true():
pool.add(['start', 'not_null'], '1212')
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
@@ -254,6 +256,7 @@ def test_execute_if_else_result_false():
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(

View File

@@ -1,17 +1,62 @@
+import time
+import uuid
 from unittest import mock
 from uuid import uuid4
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.segments import ArrayStringVariable, StringVariable
-from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.node_entities import SystemVariable, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.nodes.base_node import UserFrom
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
+from models.workflow import WorkflowType
DEFAULT_NODE_ID = 'node_id'
def test_overwrite_string_variable():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "assigner",
},
"id": "assigner"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
conversation_variable = StringVariable(
id=str(uuid4()),
name='test_conversation_variable',
@@ -24,13 +69,27 @@ def test_overwrite_string_variable():
         value='the second value',
     )

+    # construct variable pool
+    variable_pool = VariablePool(
+        system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[conversation_variable],
+    )
+    variable_pool.add(
+        [DEFAULT_NODE_ID, input_variable.name],
+        input_variable,
+    )
+
     node = VariableAssignerNode(
-        tenant_id='tenant_id',
-        app_id='app_id',
-        workflow_id='workflow_id',
-        user_id='user_id',
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=time.perf_counter()
+        ),
         config={
             'id': 'node_id',
             'data': {
@@ -41,19 +100,8 @@ def test_overwrite_string_variable():
         },
     )

-    variable_pool = VariablePool(
-        system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
-        user_inputs={},
-        environment_variables=[],
-        conversation_variables=[conversation_variable],
-    )
-    variable_pool.add(
-        [DEFAULT_NODE_ID, input_variable.name],
-        input_variable,
-    )
-
     with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
-        node.run(variable_pool)
+        list(node.run())
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
@@ -63,6 +111,46 @@ def test_overwrite_string_variable():
def test_append_variable_to_array():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "assigner",
},
"id": "assigner"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
@@ -75,23 +163,6 @@ def test_append_variable_to_array():
         value='the second value',
     )

-    node = VariableAssignerNode(
-        tenant_id='tenant_id',
-        app_id='app_id',
-        workflow_id='workflow_id',
-        user_id='user_id',
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
-        config={
-            'id': 'node_id',
-            'data': {
-                'assigned_variable_selector': ['conversation', conversation_variable.name],
-                'write_mode': WriteMode.APPEND.value,
-                'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
-            },
-        },
-    )
variable_pool = VariablePool(
system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
user_inputs={},
@@ -103,8 +174,26 @@ def test_append_variable_to_array():
         input_variable,
     )

+    node = VariableAssignerNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=time.perf_counter()
+        ),
+        config={
+            'id': 'node_id',
+            'data': {
+                'assigned_variable_selector': ['conversation', conversation_variable.name],
+                'write_mode': WriteMode.APPEND.value,
+                'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
+            },
+        },
+    )
+
     with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
-        node.run(variable_pool)
+        list(node.run())
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
@@ -113,19 +202,67 @@ def test_append_variable_to_array():
def test_clear_array():
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "assigner",
},
"id": "assigner"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
graph_config=graph_config,
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
)
+    variable_pool = VariablePool(
+        system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
+        user_inputs={},
+        environment_variables=[],
+        conversation_variables=[conversation_variable],
+    )
+
     node = VariableAssignerNode(
-        tenant_id='tenant_id',
-        app_id='app_id',
-        workflow_id='workflow_id',
-        user_id='user_id',
-        user_from=UserFrom.ACCOUNT,
-        invoke_from=InvokeFrom.DEBUGGER,
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=time.perf_counter()
+        ),
         config={
             'id': 'node_id',
             'data': {
@@ -136,14 +273,9 @@ def test_clear_array():
         },
     )

-    variable_pool = VariablePool(
-        system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'},
-        user_inputs={},
-        environment_variables=[],
-        conversation_variables=[conversation_variable],
-    )
-
-    node.run(variable_pool)
+    with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run:
+        list(node.run())
+        mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
assert got is not None
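Taken together, the rewritten assigner tests hand the VariablePool to the node through GraphRuntimeState and drain node.run() instead of calling node.run(variable_pool). A minimal sketch with stand-in classes, assuming run() is a generator of events as the list(node.run()) calls suggest:

from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class _RuntimeState:
    # Stand-in for GraphRuntimeState: the pool now travels with the state.
    variable_pool: dict = field(default_factory=dict)


@dataclass
class _AssignerNode:
    # Stand-in for VariableAssignerNode.
    graph_runtime_state: _RuntimeState

    def run(self) -> Generator[str, None, None]:
        # run() is a generator: nothing executes until it is iterated.
        self.graph_runtime_state.variable_pool['conversation.x'] = 'assigned'
        yield 'succeeded'


state = _RuntimeState()
node = _AssignerNode(graph_runtime_state=state)

events = list(node.run())  # draining the generator is what performs the assignment
assert events == ['succeeded']
assert state.variable_pool['conversation.x'] == 'assigned'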