mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 01:15:53 +08:00
fix: add json parse step
This commit is contained in:
parent
14b65efcf7
commit
cb664f8c97
@ -362,7 +362,7 @@ class LLMGenerator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||||
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
|
model_parameters = model_config.get("model_parameters", {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = cast(
|
response = cast(
|
||||||
|
@ -198,14 +198,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
structured_output = {}
|
structured_output = {}
|
||||||
try:
|
try:
|
||||||
structured_output = parse_partial_json(result_text)
|
structured_output = parse_partial_json(result_text)
|
||||||
except json.JSONDecodeError:
|
except (json.JSONDecodeError, ValueError):
|
||||||
# Try to find JSON string within triple backticks
|
# Try to find JSON string within triple backticks
|
||||||
_json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
|
_json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
|
||||||
match = _json_markdown_re.search(result_text)
|
match = _json_markdown_re.search(result_text)
|
||||||
# If no match found, assume the entire string is a JSON string
|
# If no match found, assume the entire string is a JSON string
|
||||||
# Else, use the content within the backticks
|
# Else, use the content within the backticks
|
||||||
json_str = result_text if match is None else match.group(2)
|
json_str = result_text if match is None else match.group(2)
|
||||||
|
try:
|
||||||
structured_output = parse_partial_json(json_str)
|
structured_output = parse_partial_json(json_str)
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||||
outputs["structured_output"] = structured_output
|
outputs["structured_output"] = structured_output
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
|
@ -78,4 +78,16 @@ def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
|||||||
# If we got here, we ran out of characters to remove
|
# If we got here, we ran out of characters to remove
|
||||||
# and still couldn't parse the string as JSON, so return the parse error
|
# and still couldn't parse the string as JSON, so return the parse error
|
||||||
# for the original string.
|
# for the original string.
|
||||||
|
try:
|
||||||
return json.loads(s, strict=strict)
|
return json.loads(s, strict=strict)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return extract_json(s)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json(response: str):
|
||||||
|
try:
|
||||||
|
json_start = response.index("{")
|
||||||
|
json_end = response.rfind("}")
|
||||||
|
return json.loads(response[json_start : json_end + 1])
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise ValueError("output is not a valid json str")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user