From a056a9d6010a2dfb5bb9ced1f6d50104016b2176 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 22 Dec 2024 10:40:43 +0800 Subject: [PATCH] feat(code_node): add more check (#11949) Signed-off-by: -LAN- --- api/core/helper/code_executor/code_executor.py | 2 +- .../helper/code_executor/template_transformer.py | 13 +++++++++---- api/core/workflow/nodes/code/code_node.py | 16 +++++++--------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ea..584e3e9698 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -118,7 +118,7 @@ class CodeExecutor: return response.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index cf422fd023..605719747a 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,7 +25,7 @@ class TemplateTransformer(ABC): return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") @@ -33,15 +33,20 @@ class TemplateTransformer(ABC): return result @classmethod - def transform_response(cls, response: str): + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - result = json.loads(cls.extract_result_str_from_response(response)) + try: + result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") if not isinstance(result, dict): - raise ValueError("Result must be a dict") + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") return result @classmethod diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 6adf82c455..4e371ca436 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result, self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -116,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]): return value def _transform_result( - self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 - ) -> dict: - """ - Transform result - :param result: result - :param output_schema: output schema - :return: - """ + self, + result: Mapping[str, Any], + output_schema: Optional[dict[str, CodeNodeData.Output]], + prefix: str = "", + depth: int = 1, + ): if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")