mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 00:58:57 +08:00
robust for json parser (#17687)
This commit is contained in:
parent
0e0220bdbf
commit
5541a1f80e
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
@ -58,6 +59,30 @@ from .prompts import (
|
|||||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
|
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json(text: str) -> Optional[str]:
    """
    Extract the leading complete JSON object/array from ``text``.

    ``text`` is expected to start with '{' or '['. The text is scanned once
    while a stack of currently open brackets is maintained; the slice up to and
    including the bracket that closes the outermost structure is returned.

    Brackets that appear inside double-quoted JSON string literals (including
    strings containing escaped quotes such as ``\\"``) are ignored, so a value
    like ``{"a": "}"}`` is returned whole instead of being cut at the brace
    inside the string.

    :param text: candidate text, normally beginning with '{' or '['
    :return: the balanced prefix ``text[: i + 1]``; ``text[:i]`` when a closing
        bracket at index ``i`` has no opener or does not match the innermost
        opener; ``None`` when the text ends before the structure is closed
        (this includes empty text and an unterminated string literal).
    """
    closer_to_opener = {"}": "{", "]": "["}
    stack = []
    in_string = False  # currently inside a double-quoted string literal
    escaped = False  # previous character inside the string was a backslash

    for i, c in enumerate(text):
        if in_string:
            # Inside a string only the terminating quote matters; a backslash
            # escapes exactly the next character (e.g. \" or \\).
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
            continue

        if c == '"':
            # Strings are only tracked once a structure has been opened, so
            # text preceding the first bracket is treated as before.
            if stack:
                in_string = True
        elif c in {"{", "["}:
            stack.append(c)
        elif c in closer_to_opener:
            # A closer with nothing open, or one that does not match the
            # innermost opener, ends the scan just before it.
            if not stack or stack[-1] != closer_to_opener[c]:
                return text[:i]
            stack.pop()
            if not stack:
                # Outermost structure closed: this is the complete JSON text.
                return text[: i + 1]

    return None
|
||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNode(LLMNode):
|
class ParameterExtractorNode(LLMNode):
|
||||||
"""
|
"""
|
||||||
@ -594,27 +619,6 @@ class ParameterExtractorNode(LLMNode):
|
|||||||
Extract complete json response.
|
Extract complete json response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def extract_json(text):
|
|
||||||
"""
|
|
||||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
|
||||||
"""
|
|
||||||
stack = []
|
|
||||||
for i, c in enumerate(text):
|
|
||||||
if c in {"{", "["}:
|
|
||||||
stack.append(c)
|
|
||||||
elif c in {"}", "]"}:
|
|
||||||
# check if stack is empty
|
|
||||||
if not stack:
|
|
||||||
return text[:i]
|
|
||||||
# check if the last element in stack is matching
|
|
||||||
if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["):
|
|
||||||
stack.pop()
|
|
||||||
if not stack:
|
|
||||||
return text[: i + 1]
|
|
||||||
else:
|
|
||||||
return text[:i]
|
|
||||||
return None
|
|
||||||
|
|
||||||
# extract json from the text
|
# extract json from the text
|
||||||
for idx in range(len(result)):
|
for idx in range(len(result)):
|
||||||
if result[idx] == "{" or result[idx] == "[":
|
if result[idx] == "{" or result[idx] == "[":
|
||||||
@ -624,6 +628,7 @@ class ParameterExtractorNode(LLMNode):
|
|||||||
return cast(dict, json.loads(json_str))
|
return cast(dict, json.loads(json_str))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
logger.info(f"extra error: {result}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||||
@ -633,7 +638,18 @@ class ParameterExtractorNode(LLMNode):
|
|||||||
if not tool_call or not tool_call.function.arguments:
|
if not tool_call or not tool_call.function.arguments:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return cast(dict, json.loads(tool_call.function.arguments))
|
result = tool_call.function.arguments
|
||||||
|
# extract json from the arguments
|
||||||
|
for idx in range(len(result)):
|
||||||
|
if result[idx] == "{" or result[idx] == "[":
|
||||||
|
json_str = extract_json(result[idx:])
|
||||||
|
if json_str:
|
||||||
|
try:
|
||||||
|
return cast(dict, json.loads(json_str))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.info(f"extra error: {result}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||||
"""
|
"""
|
||||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.model_runtime.entities import AssistantPromptMessage
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
@ -311,6 +312,46 @@ def test_extract_json_response():
|
|||||||
assert result["location"] == "kawaii"
|
assert result["location"] == "kawaii"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_json_from_tool_call():
    """
    Check that JSON is extracted from a tool call's arguments.

    The arguments string holds two JSON objects back to back; the assertions
    expect the parsed result to come from the first one.
    """
    model_config = {
        "provider": "langgenius/openai/openai",
        "name": "gpt-3.5-turbo-instruct",
        "mode": "completion",
        "completion_params": {},
    }
    location_parameter = {"name": "location", "type": "string", "description": "location", "required": True}
    node_config = {
        "id": "llm",
        "data": {
            "title": "123",
            "type": "parameter-extractor",
            "model": model_config,
            "query": ["sys", "query"],
            "parameters": [location_parameter],
            "reasoning_mode": "prompt",
            "instruction": "{{#sys.query#}}",
            "memory": None,
        },
    }
    node = init_parameter_extractor_node(config=node_config)

    # Two concatenated objects: the second one is trailing content.
    function = AssistantPromptMessage.ToolCall.ToolCallFunction(
        name="foo",
        arguments="""{"location":"kawaii"}{"location": 1}""",
    )
    tool_call = AssistantPromptMessage.ToolCall(
        id="llm",
        type="parameter-extractor",
        function=function,
    )

    result = node._extract_json_from_tool_call(tool_call)

    assert result is not None
    assert result["location"] == "kawaii"
|
||||||
|
|
||||||
|
|
||||||
def test_chat_parameter_extractor_with_memory(setup_model_mock):
|
def test_chat_parameter_extractor_with_memory(setup_model_mock):
|
||||||
"""
|
"""
|
||||||
Test chat parameter extractor with memory.
|
Test chat parameter extractor with memory.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user