diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 2627b7347e..9cdccd8e7f 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -362,7 +362,7 @@ class LLMGenerator: ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} + model_parameters = model_config.get("model_parameters", {}) try: response = cast( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f9733a88d6..146fcec39a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -198,14 +198,17 @@ class LLMNode(BaseNode[LLMNodeData]): structured_output = {} try: structured_output = parse_partial_json(result_text) - except json.JSONDecodeError: + except (json.JSONDecodeError, ValueError): # Try to find JSON string within triple backticks _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL) match = _json_markdown_re.search(result_text) # If no match found, assume the entire string is a JSON string # Else, use the content within the backticks json_str = result_text if match is None else match.group(2) - structured_output = parse_partial_json(json_str) + try: + structured_output = parse_partial_json(json_str) + except (json.JSONDecodeError, ValueError) as e: + raise LLMNodeError(f"Failed to parse structured output: {e}") outputs["structured_output"] = structured_output yield RunCompletedEvent( run_result=NodeRunResult( diff --git a/api/core/workflow/utils/structured_output/utils.py b/api/core/workflow/utils/structured_output/utils.py index b8dd3b7273..16b8ffe91a 100644 --- a/api/core/workflow/utils/structured_output/utils.py +++ b/api/core/workflow/utils/structured_output/utils.py @@ -78,4 +78,16 @@ def parse_partial_json(s: str, *, strict: bool = False) -> Any: # If we got here, we ran out of characters to remove # and still couldn't parse the string as JSON, so 
# return the parse error for the original string.
#
# Tail of the parse_partial_json hunk. That function's body begins before this
# chunk, so its final fallback is reproduced here as context only:
#   -    return json.loads(s, strict=strict)
#   +    try:
#   +        return json.loads(s, strict=strict)
#   +    except json.JSONDecodeError:
#   +        return extract_json(s)


def extract_json(response: str, *, strict: bool = True) -> Any:
    """Decode the JSON object embedded in *response*.

    The candidate payload is the span from the first ``{`` to the last ``}``
    (inclusive), so any text before or after the object is ignored.

    Args:
        response: Text expected to contain a JSON object.
        strict: Forwarded to ``json.loads``. Defaults to ``True``, which is
            what ``json.loads`` uses when the flag is omitted; pass ``False``
            to accept raw control characters inside strings.

    Returns:
        The value decoded from the ``{...}`` span.

    Raises:
        ValueError: If *response* has no ``{``, has no ``}`` after it, or the
            span is not valid JSON. When decoding was attempted, the
            underlying ``json.JSONDecodeError`` is chained as ``__cause__``.
    """
    json_start = response.find("{")
    json_end = response.rfind("}")
    # Reject a missing or reversed brace pair explicitly rather than letting
    # an empty slice reach the decoder.
    if json_start == -1 or json_end < json_start:
        raise ValueError("output is not a valid json str")

    try:
        return json.loads(response[json_start : json_end + 1], strict=strict)
    except ValueError as err:
        # json.JSONDecodeError is a ValueError subclass; chain it so the
        # position and reason of the failure stay attached to the error.
        raise ValueError("output is not a valid json str") from err