chore(list_operator): refine exception handling for error specificity (#10206)

This commit is contained in:
-LAN- 2024-11-03 11:55:19 +08:00 committed by GitHub
parent ec6a03afdd
commit 1432c268a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 112 additions and 57 deletions

View File

@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""
pass
class InvalidFilterValueError(ListOperatorError):
pass
class InvalidKeyError(ListOperatorError):
pass
class InvalidConditionError(ListOperatorError):
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Literal from typing import Literal, Union
from core.file import File from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]): class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = ( error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment" "or ArrayStringSegment"
@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) )
if isinstance(variable, ArrayFileSegment): if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value] process_data["variable"] = [item.to_dict() for item in variable.value]
else: else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value process_data["variable"] = variable.value
# Filter try:
if self.node_data.filter_by.enabled: # Filter
for condition in self.node_data.filter_by.conditions: if self.node_data.filter_by.enabled:
if isinstance(variable, ArrayStringSegment): variable = self._apply_filter(variable)
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
if isinstance(condition.value, str):
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
# Order # Order
if self.node_data.order_by.enabled: if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except ListOperatorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value) if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment): elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value) if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment): elif isinstance(variable, ArrayFileSegment):
result = _order_file( if isinstance(condition.value, str):
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
else:
value = condition.value
filter_func = _get_file_filter_func(
key=condition.key,
condition=condition.comparison_operator,
value=value,
) )
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable
# Slice def _apply_order(
if self.node_data.limit.enabled: self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
result = variable.value[: self.node_data.limit.size] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
outputs = { def _apply_slice(
"result": variable.value, self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
"first_record": variable.value[0] if variable.value else None, ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
"last_record": variable.value[-1] if variable.value else None, result = variable.value[: self.node_data.limit.size]
} return variable.model_copy(update={"value": result})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
case "size": case "size":
return lambda x: x.size return lambda x: x.size
case _: case _:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "url": case "url":
return lambda x: x.remote_url or "" return lambda x: x.remote_url or ""
case _: case _:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty": case "not empty":
return lambda x: x != "" return lambda x: x != ""
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in": case "not in":
return lambda x: not _in(value)(x) return lambda x: not _in(value)(x)
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "": case "":
return _ge(value) return _ge(value)
case _: case _:
raise ValueError(f"Invalid condition: {condition}") raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key) extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else: else:
raise ValueError(f"Invalid key: {key}") raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str): def _contains(value: str):