diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 7b1b8cf483..be43639fc0 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,4 +1,5 @@ import json +import logging import uuid from collections.abc import Mapping, Sequence from typing import Any, Optional, cast @@ -58,6 +59,30 @@ from .prompts import ( FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, ) +logger = logging.getLogger(__name__) + + +def extract_json(text): + """ + From a given JSON started from '{' or '[' extract the complete JSON object. + """ + stack = [] + for i, c in enumerate(text): + if c in {"{", "["}: + stack.append(c) + elif c in {"}", "]"}: + # check if stack is empty + if not stack: + return text[:i] + # check if the last element in stack is matching + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): + stack.pop() + if not stack: + return text[: i + 1] + else: + return text[:i] + return None + class ParameterExtractorNode(LLMNode): """ @@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode): Extract complete json response. """ - def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - # extract json from the text for idx in range(len(result)): if result[idx] == "{" or result[idx] == "[": @@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode): return cast(dict, json.loads(json_str)) except Exception: pass + logger.info(f"extra error: {result}") return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: @@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode): if not tool_call or not tool_call.function.arguments: return None - return cast(dict, json.loads(tool_call.function.arguments)) + result = tool_call.function.arguments + # extract json from the arguments + for idx in range(len(result)): + if result[idx] == "{" or result[idx] == "[": + json_str = extract_json(result[idx:]) + if json_str: + try: + return cast(dict, json.loads(json_str)) + except Exception: + pass + logger.info(f"extra error: {result}") + return None def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index ca055f5cc5..5c6bb82024 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,6 +5,7 @@ from typing import Optional from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph @@ -311,6 +312,46 @@ def test_extract_json_response(): assert result["location"] == "kawaii" +def test_extract_json_from_tool_call(): + """ + Test extract json response. + """ + + node = init_parameter_extractor_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "parameter-extractor", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": {}, + }, + "query": ["sys", "query"], + "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}], + "reasoning_mode": "prompt", + "instruction": "{{#sys.query#}}", + "memory": None, + }, + }, + ) + + result = node._extract_json_from_tool_call( + AssistantPromptMessage.ToolCall( + id="llm", + type="parameter-extractor", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="foo", arguments="""{"location":"kawaii"}{"location": 1}""" + ), + ) + ) + + assert result is not None + assert result["location"] == "kawaii" + + def test_chat_parameter_extractor_with_memory(setup_model_mock): """ Test chat parameter extractor with memory.