refactor(iteration_node): use Sequence and Mapping in parameters (#11483)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-09 15:41:20 +08:00 committed by GitHub
parent c3c6a48059
commit 537068cfde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 23 deletions

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
@ -140,8 +141,8 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
@ -153,18 +154,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -167,17 +167,17 @@ class IterationNode(BaseNode[IterationNodeData]):
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(), # type: ignore
q,
iterator_list_value,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
index,
item,
iter_run_map,
flask_app=current_app._get_current_object(), # type: ignore
q=q,
iterator_list_value=iterator_list_value,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
index=index,
item=item,
iter_run_map=iter_run_map,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
@ -370,9 +370,9 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter(
self,
*,
iterator_list_value: list[str],
iterator_list_value: Sequence[str],
variable_pool: VariablePool,
inputs: dict[str, list],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
@ -559,10 +559,11 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter_parallel(
self,
*,
flask_app: Flask,
q: Queue,
iterator_list_value: list[str],
inputs: dict[str, list],
iterator_list_value: Sequence[str],
inputs: Mapping[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",