Make the JSON parser more robust (#17687)

This commit is contained in:
zxfishhack 2025-04-10 22:18:26 +08:00 committed by GitHub
parent 0e0220bdbf
commit 5541a1f80e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 79 additions and 22 deletions

View File

@ -1,4 +1,5 @@
import json import json
import logging
import uuid import uuid
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -58,6 +59,30 @@ from .prompts import (
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
) )
logger = logging.getLogger(__name__)
def extract_json(text):
    """
    Extract the first complete JSON value from ``text``, which is expected to
    start at a '{' or '[' character.

    The scan keeps a bracket stack and — unlike a naive scanner — ignores
    braces/brackets that occur inside JSON string literals (including escaped
    quotes), so payloads such as ``{"a": "}"}`` are extracted in full.

    Returns:
        The substring covering the complete JSON value; a best-effort prefix
        when an unmatched or mismatched closing bracket is met (callers
        validate the result with ``json.loads``); or ``None`` when the value
        never closes.
    """
    stack = []
    in_string = False  # currently inside a JSON string literal
    escaped = False  # previous char was a backslash inside a string
    for i, c in enumerate(text):
        if in_string:
            # Brackets inside string literals are data, not structure.
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
            continue
        if c == '"':
            in_string = True
        elif c in {"{", "["}:
            stack.append(c)
        elif c in {"}", "]"}:
            # check if stack is empty
            if not stack:
                return text[:i]
            # check if the last element in stack is matching
            if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
                stack.pop()
                if not stack:
                    # Outermost bracket closed: the value is complete.
                    return text[: i + 1]
            else:
                # Mismatched closing bracket: return what we have so far.
                return text[:i]
    return None
class ParameterExtractorNode(LLMNode): class ParameterExtractorNode(LLMNode):
""" """
@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode):
Extract complete json response. Extract complete json response.
""" """
def extract_json(text):
    """
    From a given JSON started from '{' or '[' extract the complete JSON object.

    Returns the substring up to and including the bracket that closes the
    outermost value, a truncated prefix on an unmatched/mismatched closing
    bracket, or None if the value never closes.
    """
    # Bracket stack: one entry per currently-open '{' or '['.
    # NOTE(review): braces inside JSON string literals are not skipped, so an
    # input like '{"a": "}"}' is truncated early — confirm callers tolerate
    # the resulting json.loads failure.
    stack = []
    for i, c in enumerate(text):
        if c in {"{", "["}:
            stack.append(c)
        elif c in {"}", "]"}:
            # check if stack is empty
            if not stack:
                return text[:i]
            # check if the last element in stack is matching
            if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
                stack.pop()
                if not stack:
                    # Outermost bracket closed: the value is complete.
                    return text[: i + 1]
            else:
                # Mismatched closing bracket: give back the scanned prefix.
                return text[:i]
    return None
# extract json from the text # extract json from the text
for idx in range(len(result)): for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[": if result[idx] == "{" or result[idx] == "[":
@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode):
return cast(dict, json.loads(json_str)) return cast(dict, json.loads(json_str))
except Exception: except Exception:
pass pass
logger.info(f"extra error: {result}")
return None return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode):
if not tool_call or not tool_call.function.arguments: if not tool_call or not tool_call.function.arguments:
return None return None
return cast(dict, json.loads(tool_call.function.arguments)) result = tool_call.function.arguments
# extract json from the arguments
for idx in range(len(result)):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
try:
return cast(dict, json.loads(json_str))
except Exception:
pass
logger.info(f"extra error: {result}")
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
""" """

View File

@ -5,6 +5,7 @@ from typing import Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
@ -311,6 +312,46 @@ def test_extract_json_response():
assert result["location"] == "kawaii" assert result["location"] == "kawaii"
def test_extract_json_from_tool_call():
    """
    _extract_json_from_tool_call should return the first complete JSON object
    when the tool-call arguments contain trailing junk after it.
    """
    node_config = {
        "id": "llm",
        "data": {
            "title": "123",
            "type": "parameter-extractor",
            "model": {
                "provider": "langgenius/openai/openai",
                "name": "gpt-3.5-turbo-instruct",
                "mode": "completion",
                "completion_params": {},
            },
            "query": ["sys", "query"],
            "parameters": [{"name": "location", "type": "string", "description": "location", "required": True}],
            "reasoning_mode": "prompt",
            "instruction": "{{#sys.query#}}",
            "memory": None,
        },
    }
    node = init_parameter_extractor_node(config=node_config)

    # Two concatenated JSON objects: only the first one should be parsed.
    tool_call = AssistantPromptMessage.ToolCall(
        id="llm",
        type="parameter-extractor",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
            name="foo", arguments='{"location":"kawaii"}{"location": 1}'
        ),
    )
    extracted = node._extract_json_from_tool_call(tool_call)

    assert extracted is not None
    assert extracted["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock): def test_chat_parameter_extractor_with_memory(setup_model_mock):
""" """
Test chat parameter extractor with memory. Test chat parameter extractor with memory.