diff --git a/api/.env.example b/api/.env.example index 880453161e..73815ce371 100644 --- a/api/.env.example +++ b/api/.env.example @@ -310,6 +310,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 MULTIMODAL_SEND_FORMAT=base64 PROMPT_GENERATION_MAX_TOKENS=512 CODE_GENERATION_MAX_TOKENS=1024 +STRUCTURED_OUTPUT_MAX_TOKENS=1024 # Mail configuration, support: resend, smtp MAIL_TYPE= diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 8518d34a8e..0ab2aaafbb 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -85,5 +85,37 @@ class RuleCodeGenerateApi(Resource): return code_result +class RuleStructuredOutputGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + account = current_user + structured_output_max_tokens = int(os.getenv("STRUCTURED_OUTPUT_MAX_TOKENS", "1024")) + try: + structured_output = LLMGenerator.generate_structured_output( + tenant_id=account.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + max_tokens=structured_output_max_tokens, + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + return structured_output + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") +api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 75687f9ae3..2627b7347e 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from core.llm_generator.prompts import ( GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + STRUCTURED_OUTPUT_GENERATE_TEMPLATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -340,3 +341,43 @@ class LLMGenerator: answer = cast(str, response.message.content) return answer.strip() + + @classmethod + def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict, max_tokens: int): + prompt_template = PromptTemplateParser(STRUCTURED_OUTPUT_GENERATE_TEMPLATE) + + prompt = prompt_template.format( + inputs={ + "INSTRUCTION": instruction, + }, + remove_template_variables=False, + ) + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + prompt_messages = [UserPromptMessage(content=prompt)] + model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_json_schema = cast(str, response.message.content) + return {"output": generated_json_schema, "error": ""} 
+
+        except InvokeError as e:
+            error = str(e)
+            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+        except Exception as e:
+            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index f9411e9ec7..af99ca5e2b 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -220,3 +220,106 @@ Here is the task description: {{INPUT_TEXT}}
 
 You just need to generate the output
 """  # noqa: E501
+
+STRUCTURED_OUTPUT_GENERATE_TEMPLATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "name": { "type": "string" },
+    "age": { "type": "number" }
+  },
+  "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "title": { "type": "string" },
+    "author": { "type": "string" },
+    "publicationYear": { "type": "integer" },
+    "pageCount": { "type": "integer" }
+  },
+  "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "email": {
+      "type": "string",
+      "format": "email"
+    },
+    "password": {
+      "type": "string",
+      "minLength": 8
+    },
+    "age": {
+      "type": "integer",
+      "minimum": 18
+    }
+  },
+  "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need an album schema; the album has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+  "type": "object",
+  "properties": {
+    "songs": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+          "name": {
+            "type": "string"
+          },
+          "duration": {
+            "type": "string"
+          },
+          "artist": {
+            "type": "string"
+          }
+        },
+        "required": [
+          "name",
+          "duration",
+          "artist"
+        ]
+      }
+    }
+  },
+  "required": [
+    "songs"
+  ]
+}
+
+Now, generate a JSON Schema based on my description:
+**User Input:** {{INSTRUCTION}}
+**JSON Schema Output:**
+"""  # noqa: E501
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 0d5e8a3e4b..d202690e67 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,7 +1,11 @@
+import json
 import logging
 from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
 
+from packaging import version
+from packaging.version import Version
+
 from configs import dify_config
 from core.entities.embedding_type import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -9,8 +13,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -20,7 +24,9 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from extensions.ext_redis import redis_client
 from models.provider import ProviderType
@@ -160,6 +166,13 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
 
+        if model_parameters and model_parameters.get("structured_output_schema"):
+            result = self._handle_structured_output(
+                model_parameters=model_parameters,
+                prompt=prompt_messages,
+            )
+            prompt_messages = result["prompt"]
+            model_parameters = result["parameters"]
         return cast(
             Union[LLMResult, Generator],
             self._round_robin_invoke(
@@ -410,6 +423,89 @@ class ModelInstance:
             model=self.model, credentials=self.credentials, language=language
         )
 
+    def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
+        """
+        Handle structured output
+
+        :param model_parameters: model parameters
+        :param prompt: prompt messages
+        :return: dict containing the updated prompt messages and model parameters
+        """
+        structured_output_schema = model_parameters.pop("structured_output_schema")
+        if not structured_output_schema:
+            raise ValueError("Please provide a valid structured output schema")
+
+        try:
+            schema = json.loads(structured_output_schema)
+        except json.JSONDecodeError:
+            raise ValueError("structured_output_schema is not valid JSON format")
+
+        model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
+        if not model_schema:
+            raise ValueError("Unable to fetch model schema")
+
+        supported_schema_keys = ["json_schema", "format"]
+        rules = model_schema.parameter_rules
+        schema_key = next((rule.name for rule in rules if rule.name in supported_schema_keys), None)
+
+        if schema_key == "json_schema":
+            name = {"name": "llm_response"}
+            if "gemini" in self.model:
+
+                def remove_additional_properties(schema):
+                    # Gemini's json_schema parameter rejects "additionalProperties",
+                    # so strip it recursively from every nested schema.
+                    if isinstance(schema, dict):
+                        for key, value in list(schema.items()):
+                            if key == "additionalProperties":
+                                del schema[key]
+                            else:
+                                remove_additional_properties(value)
+                    elif isinstance(schema, list):
+                        # Recurse into schema lists such as "anyOf"/"oneOf" values.
+                        for item in schema:
+                            remove_additional_properties(item)
+
+                remove_additional_properties(schema)
+                schema_json = schema
+            else:
+                schema_json = {"schema": schema, **name}
+
+            model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        elif schema_key == "format" and self.plugin_version > version.parse("0.0.3"):
+            model_parameters["format"] = json.dumps(schema, ensure_ascii=False)
+        else:
+            content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
+            structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", structured_output_schema).replace(
+                "{{question}}", content
+            )
+            structured_output_prompt_message = UserPromptMessage(content=structured_output_prompt)
+            prompt = list(prompt[:-1]) + [structured_output_prompt_message]
+            return {"prompt": prompt, "parameters": model_parameters}
+        for rule in rules:
+            if rule.name == "response_format":
+                model_parameters["response_format"] = "JSON" if "JSON" in rule.options else "json_schema"
+        return {"prompt": prompt, "parameters": model_parameters}
+
+    def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
+        """
+        Fetch the model schema for the given provider and model
+        """
+        model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
+        return model_provider.get_model_schema(
+            provider=provider, model_type=model_type, model=model, credentials=self.credentials
+        )
+
+    @property
+    def plugin_version(self) -> Version:
+        """
+        Return the version of the plugin that provides this model
+        """
+        return version.parse(
+            self.model_type_instance.plugin_model_provider.plugin_unique_identifier.split(":")[1].split("@")[0]
+        )
+

 class ModelManager:
     def __init__(self) -> None:
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index bf54fdb80c..486b4b01af 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+    structured_output: dict | None = None
+    structured_output_enabled: bool = False
 
     @field_validator("prompt_config", mode="before")
     @classmethod
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index fe0ed3e564..438256c7ef 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
@@ -57,6 +58,7 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.utils import parse_partial_json
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -192,7 +194,19 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-
+            if self.node_data.structured_output_enabled and self.node_data.structured_output:
+                structured_output = {}
+                try:
+                    structured_output = parse_partial_json(result_text)
+                except json.JSONDecodeError:
+                    # Try to find a JSON string within triple backticks
+                    _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL)
+                    match = _json_markdown_re.search(result_text)
+                    # If no match is found, assume the entire string is JSON;
+                    # otherwise, use the content after the opening backticks
+                    json_str = result_text if match is None else match.group(2)
+                    structured_output = parse_partial_json(json_str)
+                outputs["structured_output"] = structured_output
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -499,6 +513,10 @@ class LLMNode(BaseNode[LLMNodeData]):
         # model config
         completion_params = node_data_model.completion_params
+        if self.node_data.structured_output_enabled and self.node_data.structured_output:
+            completion_params["structured_output_schema"] = json.dumps(
+                self.node_data.structured_output.get("schema", {}), ensure_ascii=False
+            )
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
new file mode 100644
index 0000000000..4fee7bb34d
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/prompt.py
@@ -0,0 +1,21 @@
+STRUCTURED_OUTPUT_PROMPT = """
+You are a helpful AI assistant. You answer questions and return the output in JSON format.
+
+Example 1:
+    Here is the JSON schema:
+    {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
+
+    Here is the user's question:
+    My name is John Doe and I am 30 years old.
+
+    output:
+    {"name": "John Doe", "age": 30}
+
+Here is the JSON schema:
+{{schema}}
+
+Here is the user's question:
+{{question}}
+output:
+
+"""  # noqa: E501
diff --git a/api/core/workflow/utils/structured_output/utils.py b/api/core/workflow/utils/structured_output/utils.py
new file mode 100644
index 0000000000..b8dd3b7273
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/utils.py
@@ -0,0 +1,81 @@
+import json
+from typing import Any
+
+
+def parse_partial_json(s: str, *, strict: bool = False) -> Any:
+    """Parse a JSON string that may be truncated, e.g. missing closing quotes, brackets, or braces.
+
+    Args:
+        s: The JSON string to parse.
+        strict: Passed through to json.loads; when False, control characters are allowed inside strings.
+
+    Returns:
+        The parsed JSON value.
+    """
+    # Attempt to parse the string as-is.
+    try:
+        return json.loads(s, strict=strict)
+    except json.JSONDecodeError:
+        pass
+
+    # Initialize variables.
+    new_chars = []
+    stack = []
+    is_inside_string = False
+    escaped = False
+
+    # Process each character in the string one at a time.
+    for char in s:
+        if is_inside_string:
+            if char == '"' and not escaped:
+                is_inside_string = False
+            elif char == "\n" and not escaped:
+                char = "\\n"  # Replace the newline character with the escape sequence.
+            elif char == "\\":
+                escaped = not escaped
+            else:
+                escaped = False
+        else:
+            if char == '"':
+                is_inside_string = True
+                escaped = False
+            elif char == "{":
+                stack.append("}")
+            elif char == "[":
+                stack.append("]")
+            elif char in ("}", "]"):
+                if stack and stack[-1] == char:
+                    stack.pop()
+                else:
+                    # Mismatched closing character; the input is malformed.
+                    return {}
+
+        # Append the processed character to the new string.
+        new_chars.append(char)
+
+    # If we're still inside a string at the end of processing,
+    # we need to close the string.
+    if is_inside_string:
+        if escaped:  # Remove the unterminated escape character.
+            new_chars.pop()
+        new_chars.append('"')
+
+    # Reverse the stack to get the closing characters.
+    stack.reverse()
+
+    # Try to parse shortened copies of the string until we succeed or run out of characters.
+    while new_chars:
+        # Close any remaining open structures in the reverse
+        # order that they were opened, then attempt to parse
+        # the modified string as JSON.
+        try:
+            return json.loads("".join(new_chars + stack), strict=strict)
+        except json.JSONDecodeError:
+            # If the string still fails to parse,
+            # drop the last character and try again.
+            new_chars.pop()
+
+    # If we got here, we ran out of characters to remove
+    # and still couldn't parse the string, so parse the original
+    # string one final time and let its JSONDecodeError propagate.
+    return json.loads(s, strict=strict)
diff --git a/docker/.env.example b/docker/.env.example
index 29073fa1b0..eb1768587d 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -595,6 +595,12 @@ PROMPT_GENERATION_MAX_TOKENS=512
 # Default: 1024 tokens.
 CODE_GENERATION_MAX_TOKENS=1024
 
+# The maximum number of tokens allowed for structured output.
+# This setting controls the upper limit of tokens the LLM can use
+# when generating a JSON Schema in the structured output generator.
+# Default: 1024 tokens.
+STRUCTURED_OUTPUT_MAX_TOKENS=1024
+
 # ------------------------------
 # Multi-modal Configuration
 # ------------------------------
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 6c95ddd1f2..9829b43bf2 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -260,6 +260,7 @@ x-shared-env: &shared-api-worker-env
   SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
   PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
   CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
+  STRUCTURED_OUTPUT_MAX_TOKENS: ${STRUCTURED_OUTPUT_MAX_TOKENS:-1024}
   MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
   UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
   UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
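A minimal usage sketch (not part of the diff) of the parse_partial_json helper added in api/core/workflow/utils/structured_output/utils.py, showing how the LLM node can still recover a structured_output value when a response is cut off mid-object, for example by the max_tokens limit. The truncated payload below is invented for illustration:

    # Hypothetical input: a model response truncated inside a string value.
    from core.workflow.utils.structured_output.utils import parse_partial_json

    truncated = '{"name": "John Doe", "age": 30, "tags": ["admin", "edi'
    print(parse_partial_json(truncated))
    # The helper closes the open string, list, and object before re-parsing:
    # {'name': 'John Doe', 'age': 30, 'tags': ['admin', 'edi']}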