mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 14:45:59 +08:00

feat: structured output

This commit is contained in:
parent 5d8b32a249
commit 0dcbdfcb8d
@ -310,6 +310,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50

MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
STRUCTURED_OUTPUT_MAX_TOKENS=1024

# Mail configuration, support: resend, smtp
MAIL_TYPE=
@ -85,5 +85,37 @@ class RuleCodeGenerateApi(Resource):

        return code_result


class RuleStructuredOutputGenerateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        account = current_user
        structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024"))
        try:
            structured_output = LLMGenerator.generate_structured_output(
                tenant_id=account.current_tenant_id,
                instruction=args["instruction"],
                model_config=args["model_config"],
                max_tokens=structured_output_max_tokens,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)

        return structured_output


api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
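For reference, a minimal sketch of calling the new endpoint from Python. The base URL, console API prefix, and auth header are assumptions for illustration, not part of this diff; the console API normally requires a logged-in session.

import requests

BASE_URL = "http://localhost:5001/console/api"  # assumed prefix, adjust to your deployment
resp = requests.post(
    f"{BASE_URL}/rule-structured-output-generate",
    json={
        "instruction": "I need name and age",
        "model_config": {"provider": "openai", "name": "gpt-4o"},  # placeholder model config
    },
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder credential
    timeout=30,
)
# On success the body is {"output": "<json schema>", "error": ""}.
print(resp.json())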
@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
    GENERATOR_QA_PROMPT,
    JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
    PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
    STRUCTURED_OUTPUT_GENERATE_TEMPLATE,
    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@ -340,3 +341,43 @@ class LLMGenerator:

        answer = cast(str, response.message.content)
        return answer.strip()

    @classmethod
    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int):
        prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE)

        prompt = prompt_template.format(
            inputs={
                "INSTRUCTION": instruction,
            },
            remove_template_variables=False,
        )

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )

        prompt_messages = [UserPromptMessage(content=prompt)]
        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}

        try:
            response = cast(
                LLMResult,
                model_instance.invoke_llm(
                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                ),
            )

            generated_json_schema = cast(str, response.message.content)
            return {"output": generated_json_schema, "error": ""}

        except InvokeError as e:
            error = str(e)
            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
        except Exception as e:
            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
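A sketch of exercising this classmethod directly; the module path, tenant id, and model_config values are placeholders:

from core.llm_generator.llm_generator import LLMGenerator  # assumed module path

result = LLMGenerator.generate_structured_output(
    tenant_id="tenant-123",  # placeholder
    instruction="I need name and age",
    model_config={"provider": "openai", "name": "gpt-4o"},  # placeholder
    max_tokens=1024,
)
# Errors are reported in-band rather than raised, so check "error" first.
if result["error"]:
    print("generation failed:", result["error"])
else:
    print("schema:", result["output"])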
@ -220,3 +220,112 @@ Here is the task description: {{INPUT_TEXT}}

You just need to generate the output
""" # noqa: E501

STRUCTURED_OUTPUT_GENERATE_TEMPLATE = """
Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.

## Instructions:

1. Analyze the user's description of their data needs
2. Identify each property that should be included in the schema
3. Determine the appropriate data type for each property
4. Decide which properties should be required
5. Generate a complete JSON Schema with proper syntax
6. Include appropriate constraints when specified (min/max values, patterns, formats)
7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
8. DO NOT use markdown code blocks (``` or ```json). Return the raw JSON Schema directly.

## Examples:

### Example 1:
**User Input:** I need name and age
**JSON Schema Output:**
{
  "type": "object",
  "properties": {
    "name": { "type": "string" },
    "age": { "type": "number" }
  },
  "required": ["name", "age"]
}

### Example 2:
**User Input:** I want to store information about books including title, author, publication year and optional page count
**JSON Schema Output:**
{
  "type": "object",
  "properties": {
    "title": { "type": "string" },
    "author": { "type": "string" },
    "publicationYear": { "type": "integer" },
    "pageCount": { "type": "integer" }
  },
  "required": ["title", "author", "publicationYear"]
}

### Example 3:
**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
**JSON Schema Output:**
{
  "type": "object",
  "properties": {
    "email": {
      "type": "string",
      "format": "email"
    },
    "password": {
      "type": "string",
      "minLength": 8
    },
    "age": {
      "type": "integer",
      "minimum": 18
    }
  },
  "required": ["email", "password", "age"]
}

### Example 4:
**User Input:** I need an album schema; the album has songs, and each song has a name, duration, and artist.
**JSON Schema Output:**
{
  "type": "object",
  "properties": {
    "songs": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "name": {
            "type": "string"
          },
          "id": {
            "type": "string"
          },
          "duration": {
            "type": "string"
          },
          "artist": {
            "type": "string"
          }
        },
        "required": [
          "name",
          "id",
          "duration",
          "artist"
        ]
      }
    }
  },
  "required": [
    "songs"
  ]
}

Now, generate a JSON Schema based on my description:
**User Input:** {{INSTRUCTION}}
**JSON Schema Output:**
""" # noqa: E501
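The template is rendered through PromptTemplateParser with remove_template_variables=False; a minimal stand-in using plain string replacement illustrates the substitution (the real parser handles {{...}} variables more generally, so this is only a sketch):

TEMPLATE = "**User Input:** {{INSTRUCTION}}\n**JSON Schema Output:**"

def render(template: str, inputs: dict[str, str]) -> str:
    # Stand-in for PromptTemplateParser.format(): replace each {{KEY}} verbatim.
    for key, value in inputs.items():
        template = template.replace("{{" + key + "}}", value)
    return template

print(render(TEMPLATE, {"INSTRUCTION": "I need name and age"}))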
@ -1,7 +1,11 @@
import json
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload

from packaging import version
from packaging.version import Version

from configs import dify_config
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@ -9,8 +13,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@ -20,7 +24,9 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from extensions.ext_redis import redis_client
from models.provider import ProviderType
@ -160,6 +166,13 @@ class ModelInstance:
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        if model_parameters and model_parameters.get("structured_output_schema"):
            result = self._handle_structured_output(
                model_parameters=model_parameters,
                prompt=prompt_messages,
            )
            prompt_messages = result["prompt"]
            model_parameters = result["parameters"]
        return cast(
            Union[LLMResult, Generator],
            self._round_robin_invoke(
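A sketch of how a caller opts in to structured output through model_parameters. Note the schema must be a JSON-encoded string, since _handle_structured_output below json.loads() it; the schema contents here are placeholders.

import json

model_parameters = {
    "temperature": 0.01,
    # JSON-encoded schema string; popped and interpreted by _handle_structured_output.
    "structured_output_schema": json.dumps({
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }),
}
# model_instance is assumed to come from ModelManager.get_model_instance(...):
# result = model_instance.invoke_llm(
#     prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
# )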
@ -410,6 +423,83 @@ class ModelInstance:
            model=self.model, credentials=self.credentials, language=language
        )

    def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
        """
        Handle structured output

        :param model_parameters: model parameters
        :param prompt: prompt messages
        :return: updated prompt messages and model parameters
        """
        structured_output_schema = model_parameters.pop("structured_output_schema")
        if not structured_output_schema:
            raise ValueError("Please provide a valid structured output schema")

        try:
            schema = json.loads(structured_output_schema)
        except json.JSONDecodeError:
            raise ValueError("structured_output_schema is not valid JSON format")

        model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
        if not model_schema:
            raise ValueError("Unable to fetch model schema")

        supported_schema_keys = ["json_schema", "format"]
        rules = model_schema.parameter_rules
        schema_key = next((rule.name for rule in rules if rule.name in supported_schema_keys), None)

        if schema_key == "json_schema":
            name = {"name": "llm_response"}
            if "gemini" in self.model:

                def remove_additional_properties(schema):
                    # Gemini does not accept additionalProperties in its schema, so strip it recursively.
                    if isinstance(schema, dict):
                        for key, value in list(schema.items()):
                            if key == "additionalProperties":
                                del schema[key]
                            else:
                                remove_additional_properties(value)

                remove_additional_properties(schema)
                schema_json = schema
            else:
                schema_json = {"schema": schema, **name}

            model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)

        elif schema_key == "format" and self.plugin_version > version.parse("0.0.3"):
            model_parameters["format"] = json.dumps(schema, ensure_ascii=False)
        else:
            # No native schema parameter: fall back to instructing the model through the prompt.
            content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
            structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", structured_output_schema).replace(
                "{{question}}", content
            )
            structured_output_prompt_message = UserPromptMessage(content=structured_output_prompt)
            prompt = list(prompt[:-1]) + [structured_output_prompt_message]
            return {"prompt": prompt, "parameters": model_parameters}
        for rule in rules:
            if rule.name == "response_format":
                model_parameters["response_format"] = "JSON" if "JSON" in rule.options else "json_schema"
        return {"prompt": prompt, "parameters": model_parameters}

    def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
        """
        Fetch model schema
        """
        model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
        return model_provider.get_model_schema(
            provider=provider, model_type=model_type, model=model, credentials=self.credentials
        )

    @property
    def plugin_version(self) -> Version:
        """
        Version of the plugin provider that backs this model
        """
        return version.parse(
            self.model_type_instance.plugin_model_provider.plugin_unique_identifier.split(":")[1].split("@")[0]
        )


class ModelManager:
    def __init__(self) -> None:
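The Gemini branch strips additionalProperties recursively before the schema is serialized; a self-contained rerun of that helper on a sample schema shows the effect (the sample is invented):

def remove_additional_properties(schema):
    # Same logic as the helper above: drop every additionalProperties key in place.
    if isinstance(schema, dict):
        for key, value in list(schema.items()):
            if key == "additionalProperties":
                del schema[key]
            else:
                remove_additional_properties(value)

sample = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"name": {"type": "string", "additionalProperties": False}},
}
remove_additional_properties(sample)
print(sample)  # {'type': 'object', 'properties': {'name': {'type': 'string'}}}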
@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
    memory: Optional[MemoryConfig] = None
    context: ContextConfig
    vision: VisionConfig = Field(default_factory=VisionConfig)
    structured_output: dict | None = None
    structured_output_enabled: bool = False

    @field_validator("prompt_config", mode="before")
    @classmethod
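For orientation, a sketch of the node-data fragment these two fields accept. The surrounding payload shape is an assumption; only the two keys and the nested "schema" key (read via structured_output.get("schema", {}) in the LLM node below) come from this diff.

# Hypothetical fragment of an LLM node's config.
node_data_fragment = {
    "structured_output_enabled": True,
    "structured_output": {
        # The LLM node serializes this value into completion_params["structured_output_schema"].
        "schema": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        }
    },
}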
@ -1,5 +1,6 @@
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
@ -57,6 +58,7 @@ from core.workflow.nodes.event import (
    RunRetrieverResourceEvent,
    RunStreamChunkEvent,
)
from core.workflow.utils.structured_output.utils import parse_partial_json
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@ -192,7 +194,19 @@ class LLMNode(BaseNode[LLMNodeData]):
                self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                break
            outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}

            if self.node_data.structured_output_enabled and self.node_data.structured_output:
                structured_output = {}
                try:
                    structured_output = parse_partial_json(result_text)
                except json.JSONDecodeError:
                    # Try to find a JSON string within triple backticks
                    _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
                    match = _json_markdown_re.search(result_text)
                    # If no match is found, assume the entire string is JSON;
                    # otherwise, use the content after the opening backticks.
                    json_str = result_text if match is None else match.group(2)
                    structured_output = parse_partial_json(json_str)
                outputs["structured_output"] = structured_output
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -499,6 +513,10 @@ class LLMNode(BaseNode[LLMNodeData]):

        # model config
        completion_params = node_data_model.completion_params
        if self.node_data.structured_output_enabled and self.node_data.structured_output:
            completion_params["structured_output_schema"] = json.dumps(
                self.node_data.structured_output.get("schema", {}), ensure_ascii=False
            )
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
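To see what the fallback regex extracts, a small standalone run (the sample completion text is invented):

import re

from core.workflow.utils.structured_output.utils import parse_partial_json

# Invented model completion that wraps its JSON in a fenced block.
result_text = 'Sure, here it is:\n```json\n{"name": "John Doe", "age": 30}\n```'

_json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
match = _json_markdown_re.search(result_text)
# group(2) is everything after the opening fence, including the trailing ```,
# which parse_partial_json tolerates by trimming unparsable tail characters.
json_str = result_text if match is None else match.group(2)
print(parse_partial_json(json_str))  # {'name': 'John Doe', 'age': 30}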
api/core/workflow/utils/structured_output/prompt.py (new file, 21 lines)
@ -0,0 +1,21 @@
STRUCTURED_OUTPUT_PROMPT = """
You’re a helpful AI assistant. You can answer questions and output in JSON format.

eg1:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}

Here is the user's question:
My name is John Doe and I am 30 years old.

output:
{"name": "John Doe", "age": 30}

Here is the JSON schema:
{{schema}}

Here is the user's question:
{{question}}
output:

""" # noqa: E501
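This template is filled by plain string replacement in _handle_structured_output above; a quick standalone render with a sample schema and question:

import json

from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT

schema = json.dumps(
    {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
)
prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema).replace(
    "{{question}}", "My name is John Doe."
)
print(prompt)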
api/core/workflow/utils/structured_output/utils.py (new file, 81 lines)
@ -0,0 +1,81 @@
import json
from typing import Any


def parse_partial_json(s: str, *, strict: bool = False) -> Any:
    """Parse a JSON string that may be missing closing braces.

    Args:
        s: The JSON string to parse.
        strict: Whether to use strict parsing. Defaults to False.

    Returns:
        The parsed JSON object as a Python dictionary.
    """
    # Attempt to parse the string as-is.
    try:
        return json.loads(s, strict=strict)
    except json.JSONDecodeError:
        pass

    # Initialize variables.
    new_chars = []
    stack = []
    is_inside_string = False
    escaped = False

    # Process each character in the string one at a time.
    for char in s:
        if is_inside_string:
            if char == '"' and not escaped:
                is_inside_string = False
            elif char == "\n" and not escaped:
                char = "\\n"  # Replace the newline character with the escape sequence.
            elif char == "\\":
                escaped = not escaped
            else:
                escaped = False
        else:
            if char == '"':
                is_inside_string = True
                escaped = False
            elif char == "{":
                stack.append("}")
            elif char == "[":
                stack.append("]")
            elif char in ("}", "]"):
                if stack and stack[-1] == char:
                    stack.pop()
                else:
                    # Mismatched closing character; the input is malformed.
                    return {}

        # Append the processed character to the new string.
        new_chars.append(char)

    # If we're still inside a string at the end of processing,
    # we need to close the string.
    if is_inside_string:
        if escaped:  # Remove the unterminated escape character.
            new_chars.pop()
        new_chars.append('"')

    # Reverse the stack to get the closing characters.
    stack.reverse()

    # Try to parse truncations of the string until we succeed or run out of characters.
    while new_chars:
        # Close any remaining open structures in the reverse
        # order that they were opened, then attempt to parse
        # the modified string as JSON.
        try:
            return json.loads("".join(new_chars + stack), strict=strict)
        except json.JSONDecodeError:
            # If we still can't parse the string as JSON,
            # try removing the last character.
            new_chars.pop()

    # If we got here, we ran out of characters to remove
    # and still couldn't parse the string as JSON, so return the parse error
    # for the original string.
    return json.loads(s, strict=strict)
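A usage example on a truncated payload, the case this helper exists for (the sample string is invented):

from core.workflow.utils.structured_output.utils import parse_partial_json

# A streamed completion cut off mid-value: the closing quote, bracket, and braces are missing.
truncated = '{"songs": [{"name": "Intro", "duration": "3:4'
print(parse_partial_json(truncated))
# {'songs': [{'name': 'Intro', 'duration': '3:4'}]}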
@ -595,6 +595,12 @@ PROMPT_GENERATION_MAX_TOKENS=512
# Default: 1024 tokens.
CODE_GENERATION_MAX_TOKENS=1024

# The maximum number of tokens allowed for structured output.
# This setting controls the upper limit of tokens that can be used by the LLM
# when generating structured output in the structured output tool.
# Default: 1024 tokens.
STRUCTURED_OUTPUT_MAX_TOKENS=1024

# ------------------------------
# Multi-modal Configuration
# ------------------------------
@ -260,6 +260,7 @@ x-shared-env: &shared-api-worker-env
  SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
  PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
  CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
  STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
  MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
  UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
  UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}