refactor(iteration_node): use Sequence and Mapping in parameters (#11483)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-09 15:41:20 +08:00 committed by GitHub
parent c3c6a48059
commit 537068cfde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 23 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
@ -140,8 +141,8 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent): class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
@ -153,18 +154,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[Mapping[str, Any]] = None
steps: int = 0 steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent): class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at") start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[Mapping[str, Any]] = None
steps: int = 0 steps: int = 0
error: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")

View File

@ -167,17 +167,17 @@ class IterationNode(BaseNode[IterationNodeData]):
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit( future: Future = thread_pool.submit(
self._run_single_iter_parallel, self._run_single_iter_parallel,
current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
q, q=q,
iterator_list_value, iterator_list_value=iterator_list_value,
inputs, inputs=inputs,
outputs, outputs=outputs,
start_at, start_at=start_at,
graph_engine, graph_engine=graph_engine,
iteration_graph, iteration_graph=iteration_graph,
index, index=index,
item, item=item,
iter_run_map, iter_run_map=iter_run_map,
) )
future.add_done_callback(thread_pool.task_done_callback) future.add_done_callback(thread_pool.task_done_callback)
futures.append(future) futures.append(future)
@ -370,9 +370,9 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter( def _run_single_iter(
self, self,
*, *,
iterator_list_value: list[str], iterator_list_value: Sequence[str],
variable_pool: VariablePool, variable_pool: VariablePool,
inputs: dict[str, list], inputs: Mapping[str, list],
outputs: list, outputs: list,
start_at: datetime, start_at: datetime,
graph_engine: "GraphEngine", graph_engine: "GraphEngine",
@ -559,10 +559,11 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter_parallel( def _run_single_iter_parallel(
self, self,
*,
flask_app: Flask, flask_app: Flask,
q: Queue, q: Queue,
iterator_list_value: list[str], iterator_list_value: Sequence[str],
inputs: dict[str, list], inputs: Mapping[str, list],
outputs: list, outputs: list,
start_at: datetime, start_at: datetime,
graph_engine: "GraphEngine", graph_engine: "GraphEngine",