feat: structured output

Novice 2025-03-19 13:39:40 +08:00
parent 5d8b32a249
commit 0dcbdfcb8d
11 changed files with 405 additions and 3 deletions

View File

@@ -310,6 +310,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
STRUCTURED_OUTPUT_MAX_TOKENS=1024
# Mail configuration, support: resend, smtp
MAIL_TYPE=

View File

@@ -85,5 +85,37 @@ class RuleCodeGenerateApi(Resource):
        return code_result


class RuleStructuredOutputGenerateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        account = current_user
        structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))

        try:
            structured_output = LLMGenerator.generate_structured_output(
                tenant_id=account.current_tenant_id,
                instruction=args["instruction"],
                model_config=args["model_config"],
                max_tokens=structured_output_max_tokens,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)

        return structured_output


api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
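
For manual testing, a client call might look like the sketch below (the /console/api prefix, the auth header, and the model values are assumptions about the deployment, not part of this commit):

import requests

# Hypothetical local deployment and credentials.
CONSOLE_URL = "http://localhost:5001"
resp = requests.post(
    f"{CONSOLE_URL}/console/api/rule-structured-output-generate",
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder token
    json={
        "instruction": "I need name and age",
        "model_config": {"provider": "openai", "name": "gpt-4o-mini"},  # assumed model
    },
)
print(resp.json())  # {"output": "<JSON Schema string>", "error": ""}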

View File

@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
    GENERATOR_QA_PROMPT,
    JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
    STRUCTURED_OUTPUT_GENERATE_TEMPLATE,
    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@@ -340,3 +341,43 @@ class LLMGenerator:
        answer = cast(str, response.message.content)
        return answer.strip()

    @classmethod
    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int):
        prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE)
        prompt = prompt_template.format(
            inputs={
                "INSTRUCTION": instruction,
            },
            remove_template_variables=False,
        )

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )

        prompt_messages = [UserPromptMessage(content=prompt)]
        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}

        try:
            response = cast(
                LLMResult,
                model_instance.invoke_llm(
                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                ),
            )

            generated_json_schema = cast(str, response.message.content)
            return {"output": generated_json_schema, "error": ""}
        except InvokeError as e:
            error = str(e)
            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
        except Exception as e:
            logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

View File

@@ -220,3 +220,112 @@ Here is the task description: {{INPUT_TEXT}}
You just need to generate the output
""" # noqa: E501
STRUCTURED_OUTPUT_GENERATE_TEMPLATE = """
Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
## Instructions:
1. Analyze the user's description of their data needs
2. Identify each property that should be included in the schema
3. Determine the appropriate data type for each property
4. Decide which properties should be required
5. Generate a complete JSON Schema with proper syntax
6. Include appropriate constraints when specified (min/max values, patterns, formats)
7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
8. DO NOT use markdown code blocks (``` or ```json). Return the raw JSON Schema directly.
## Examples:
### Example 1:
**User Input:** I need name and age
**JSON Schema Output:**
{
    "type": "object",
    "properties": {
        "name": { "type": "string" },
        "age": { "type": "number" }
    },
    "required": ["name", "age"]
}
### Example 2:
**User Input:** I want to store information about books including title, author, publication year and optional page count
**JSON Schema Output:**
{
    "type": "object",
    "properties": {
        "title": { "type": "string" },
        "author": { "type": "string" },
        "publicationYear": { "type": "integer" },
        "pageCount": { "type": "integer" }
    },
    "required": ["title", "author", "publicationYear"]
}
### Example 3:
**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
**JSON Schema Output:**
{
    "type": "object",
    "properties": {
        "email": {
            "type": "string",
            "format": "email"
        },
        "password": {
            "type": "string",
            "minLength": 8
        },
        "age": {
            "type": "integer",
            "minimum": 18
        }
    },
    "required": ["email", "password", "age"]
}
### Example 4:
**User Input:** I need an album schema; the album has songs, and each song has a name, duration, and artist.
**JSON Schema Output:**
{
    "type": "object",
    "properties": {
        "songs": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "name": { "type": "string" },
                    "duration": { "type": "string" },
                    "artist": { "type": "string" }
                },
                "required": ["name", "duration", "artist"]
            }
        }
    },
    "required": ["songs"]
}
Now, generate a JSON Schema based on my description:
**User Input:** {{INSTRUCTION}}
**JSON Schema Output:**
""" # noqa: E501

View File

@@ -1,7 +1,11 @@
import json
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from packaging import version
from packaging.version import Version
from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -9,8 +13,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -20,7 +24,9 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from extensions.ext_redis import redis_client
from models.provider import ProviderType
@@ -160,6 +166,13 @@ class ModelInstance:
            raise Exception("Model type instance is not LargeLanguageModel")
        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)

        if model_parameters and model_parameters.get("structured_output_schema"):
            result = self._handle_structured_output(
                model_parameters=model_parameters,
                prompt=prompt_messages,
            )
            prompt_messages = result["prompt"]
            model_parameters = result["parameters"]

        return cast(
            Union[LLMResult, Generator],
            self._round_robin_invoke(
@@ -410,6 +423,83 @@ class ModelInstance:
            model=self.model, credentials=self.credentials, language=language
        )

    def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
        """
        Handle structured output.

        :param model_parameters: model parameters
        :param prompt: prompt messages
        :return: updated prompt messages and model parameters
        """
        structured_output_schema = model_parameters.pop("structured_output_schema")
        if not structured_output_schema:
            raise ValueError("Please provide a valid structured output schema")
        try:
            schema = json.loads(structured_output_schema)
        except json.JSONDecodeError:
            raise ValueError("structured_output_schema is not valid JSON format")

        model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
        if not model_schema:
            raise ValueError("Unable to fetch model schema")

        supported_schema_keys = ["json_schema", "format"]
        rules = model_schema.parameter_rules
        schema_key = next((rule.name for rule in rules if rule.name in supported_schema_keys), None)
        if schema_key == "json_schema":
            name = {"name": "llm_response"}
            if "gemini" in self.model:
                # Gemini does not accept additionalProperties; strip it recursively.
                def remove_additional_properties(schema):
                    if isinstance(schema, dict):
                        for key, value in list(schema.items()):
                            if key == "additionalProperties":
                                del schema[key]
                            else:
                                remove_additional_properties(value)

                remove_additional_properties(schema)
                schema_json = schema
            else:
                schema_json = {"schema": schema, **name}
            model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
        elif schema_key == "format" and self.plugin_version > version.parse("0.0.3"):
            model_parameters["format"] = json.dumps(schema, ensure_ascii=False)
        else:
            # No native structured-output parameter: fall back to prompt engineering.
            content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
            structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace(
                "{{schema}}", structured_output_schema
            ).replace("{{question}}", content)
            structured_output_prompt_message = UserPromptMessage(content=structured_output_prompt)
            prompt = list(prompt[:-1]) + [structured_output_prompt_message]
            return {"prompt": prompt, "parameters": model_parameters}

        for rule in rules:
            if rule.name == "response_format":
                model_parameters["response_format"] = "JSON" if "JSON" in rule.options else "json_schema"
        return {"prompt": prompt, "parameters": model_parameters}

    def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
        """
        Fetch the model schema for the given provider and model.
        """
        model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
        return model_provider.get_model_schema(
            provider=provider, model_type=model_type, model=model, credentials=self.credentials
        )

    @property
    def plugin_version(self) -> Version:
        """
        Return the version of the plugin that provides this model.
        """
        return version.parse(
            self.model_type_instance.plugin_model_provider.plugin_unique_identifier.split(":")[1].split("@")[0]
        )


class ModelManager:
    def __init__(self) -> None:
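
To make the Gemini branch concrete, here is a standalone run of the additionalProperties stripping (an illustration reusing the nested helper above, not new commit code):

# Recursively drop "additionalProperties", as done for Gemini models.
def remove_additional_properties(schema):
    if isinstance(schema, dict):
        for key, value in list(schema.items()):
            if key == "additionalProperties":
                del schema[key]
            else:
                remove_additional_properties(value)

schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"name": {"type": "string", "additionalProperties": False}},
}
remove_additional_properties(schema)
print(schema)  # {'type': 'object', 'properties': {'name': {'type': 'string'}}}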

View File

@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)
    structured_output: dict | None = None
    structured_output_enabled: bool = False

    @field_validator("prompt_config", mode="before")
    @classmethod
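
For orientation, a workflow node config exercising the two new fields might look like this (the surrounding keys are illustrative; only the structured_output pair is added by this commit):

node_data = {
    "model": {"provider": "openai", "name": "gpt-4o-mini", "completion_params": {"temperature": 0.7}},
    "structured_output_enabled": True,
    "structured_output": {
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
            "required": ["name", "age"],
        }
    },
}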

View File

@@ -1,5 +1,6 @@
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
@@ -57,6 +58,7 @@ from core.workflow.nodes.event import (
    RunRetrieverResourceEvent,
    RunStreamChunkEvent,
)
from core.workflow.utils.structured_output.utils import parse_partial_json
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@@ -192,7 +194,19 @@ class LLMNode(BaseNode[LLMNodeData]):
                    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    break

            outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
            if self.node_data.structured_output_enabled and self.node_data.structured_output:
                structured_output = {}
                try:
                    structured_output = parse_partial_json(result_text)
                except json.JSONDecodeError:
                    # Try to find a JSON string within triple backticks.
                    _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
                    match = _json_markdown_re.search(result_text)

                    # If no match is found, assume the entire string is JSON;
                    # otherwise, use the content after the opening backticks.
                    json_str = result_text if match is None else match.group(2)
                    structured_output = parse_partial_json(json_str)
                outputs["structured_output"] = structured_output

            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
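
To see the fenced-JSON fallback in isolation (a made-up model reply, with parse_partial_json from the utils module below):

import re

from core.workflow.utils.structured_output.utils import parse_partial_json

# Made-up model reply that wraps its JSON in a markdown fence.
result_text = 'Sure, here it is:\n```json\n{"name": "John", "age": 30}\n```'

_json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
match = _json_markdown_re.search(result_text)
json_str = result_text if match is None else match.group(2)
print(parse_partial_json(json_str))  # {'name': 'John', 'age': 30} — the stray fence tail is trimmed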
@@ -499,6 +513,10 @@ class LLMNode(BaseNode[LLMNodeData]):
        # model config
        completion_params = node_data_model.completion_params
        if self.node_data.structured_output_enabled and self.node_data.structured_output:
            completion_params["structured_output_schema"] = json.dumps(
                self.node_data.structured_output.get("schema", {}), ensure_ascii=False
            )
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
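
Taken together with the node config sketched earlier, the completion params handed to ModelInstance.invoke_llm would then look roughly like this (illustrative values only):

import json

# The schema travels as a JSON string under the key that
# ModelInstance._handle_structured_output later pops off.
completion_params = {
    "temperature": 0.7,  # pre-existing node setting, assumed
    "structured_output_schema": json.dumps(
        {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
            "required": ["name", "age"],
        },
        ensure_ascii=False,
    ),
}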

View File

@@ -0,0 +1,21 @@
STRUCTURED_OUTPUT_PROMPT = """
You're a helpful AI assistant. You can answer questions and output in JSON format.
Example 1:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
Here is the user's question:
My name is John Doe and I am 30 years old.
output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{schema}}
Here is the user's question:
{{question}}
output:
""" # noqa: E501

View File

@@ -0,0 +1,81 @@
import json
from typing import Any


def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces.

    Args:
        s: The JSON string to parse.
        strict: Whether to use strict parsing. Defaults to False.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Initialize variables.
    new_chars = []
    stack = []
    is_inside_string = False
    escaped = False

    # Process each character in the string one at a time.
    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                char = "\\n"  # Replace the newline character with the escape sequence.
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char in ("}", "]"):
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return {}

        # Append the processed character to the new string.
        new_chars.append(char)

    # If we're still inside a string at the end of processing,
    # we need to close the string.
    if is_inside_string:
        if escaped:  # Remove the unterminated escape character.
            new_chars.pop()
        new_chars.append('"')

    # Reverse the stack to get the closing characters.
    stack.reverse()

    # Try to parse progressively shorter prefixes of the string
    # until we succeed or run out of characters.
    while new_chars:
        # Close any remaining open structures in the reverse order
        # that they were opened, then attempt to parse the result as JSON.
        try:
            return json.loads("".join(new_chars + stack), strict=strict)
        except json.JSONDecodeError:
            # If we still can't parse the string as JSON,
            # try removing the last character.
            new_chars.pop()

    # If we got here, we ran out of characters to remove
    # and still couldn't parse the string as JSON, so re-raise the parse
    # error for the original string.
    return json.loads(s, strict=strict)
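
A couple of quick checks illustrating the intended behavior on truncated model output (not part of the commit):

# Truncated object: the missing closing brace is appended.
print(parse_partial_json('{"name": "John Doe", "age": 3'))
# -> {'name': 'John Doe', 'age': 3}  (note: a truncated number parses as its prefix)

# Truncated nested structure: open containers are closed in reverse order.
print(parse_partial_json('{"songs": [{"name": "Intro", "duration": "3:21"'))
# -> {'songs': [{'name': 'Intro', 'duration': '3:21'}]}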

View File

@@ -595,6 +595,12 @@ PROMPT_GENERATION_MAX_TOKENS=512
# Default: 1024 tokens.
CODE_GENERATION_MAX_TOKENS=1024

# The maximum number of tokens allowed for structured output.
# This setting controls the upper limit of tokens the LLM may use
# when generating structured output in the structured output tool.
# Default: 1024 tokens.
STRUCTURED_OUTPUT_MAX_TOKENS=1024
# ------------------------------
# Multi-modal Configuration
# ------------------------------

View File

@@ -260,6 +260,7 @@ x-shared-env: &shared-api-worker-env
  SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
  PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
  CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
  STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
  MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
  UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
  UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}