diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 18025df39e..0d5e8a3e4b 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
@@ -10,8 +9,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -21,10 +20,7 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
-from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
-from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
-from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from extensions.ext_redis import redis_client
 from models.provider import ProviderType
@@ -164,13 +160,6 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
 
-        if model_parameters and model_parameters.get("structured_output_schema"):
-            result = self._handle_structured_output(
-                model_parameters=model_parameters,
-                prompt=prompt_messages,
-            )
-            prompt_messages = result["prompt"]
-            model_parameters = result["parameters"]
         return cast(
             Union[LLMResult, Generator],
             self._round_robin_invoke(
@@ -421,190 +410,6 @@ class ModelInstance:
             model=self.model, credentials=self.credentials, language=language
         )
 
-    def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
-        """
-        Handle structured output for language models.
-
-        This function processes schema-based structured outputs by either:
-        1. Using native JSON schema support of the model when available
-        2. Falling back to prompt-based structured output for models without native support
-
-        :param model_parameters: Model configuration parameters
-        :param prompt: Sequence of prompt messages
-        :return: Dictionary with updated prompt and parameters
-        """
-        # Extract and validate schema
-        structured_output_schema = model_parameters.pop("structured_output_schema")
-        if not structured_output_schema:
-            raise ValueError("Please provide a valid structured output schema")
-
-        try:
-            schema = json.loads(structured_output_schema)
-        except json.JSONDecodeError:
-            raise ValueError("structured_output_schema is not valid JSON format")
-
-        # Fetch model schema and validate
-        model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
-        if not model_schema:
-            raise ValueError("Unable to fetch model schema")
-
-        rules = model_schema.parameter_rules
-        rule_names = [rule.name for rule in rules]
-
-        # Handle based on model's JSON schema support capability
-        if "json_schema" in rule_names:
-            return self._handle_native_json_schema(schema, model_parameters, prompt, rules)
-        else:
-            return self._handle_prompt_based_schema(structured_output_schema, model_parameters, prompt, rules)
-
-    def _handle_native_json_schema(
-        self, schema: dict, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
-    ) -> dict:
-        """
-        Handle structured output for models with native JSON schema support.
-
-        :param schema: Parsed JSON schema
-        :param model_parameters: Model parameters to update
-        :param prompt: Sequence of prompt messages
-        :param rules: Model parameter rules
-        :return: Updated prompt and parameters
-        """
-        # Process schema according to model requirements
-        schema_json = self._prepare_schema_for_model(schema)
-
-        # Set JSON schema in parameters
-        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
-
-        # Set appropriate response format if required by the model
-        for rule in rules:
-            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
-                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
-
-        return {"prompt": prompt, "parameters": model_parameters}
-
-    def _handle_prompt_based_schema(
-        self, schema_str: str, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
-    ) -> dict:
-        """
-        Handle structured output for models without native JSON schema support.
-
-        :param schema_str: JSON schema as string
-        :param model_parameters: Model parameters to update
-        :param prompt: Sequence of prompt messages
-        :param rules: Model parameter rules
-        :return: Updated prompt and parameters
-        """
-        # Extract content from the last prompt message
-        content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
-
-        # Create structured output prompt
-        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str).replace(
-            "{{question}}", content
-        )
-
-        # Replace the last user message with our structured output prompt
-        structured_prompt_message = UserPromptMessage(content=structured_output_prompt)
-        updated_prompt = list(prompt[:-1]) + [structured_prompt_message]
-
-        # Set appropriate response format based on model capabilities
-        self._set_response_format(model_parameters, rules)
-
-        return {"prompt": updated_prompt, "parameters": model_parameters}
-
-    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
-        """
-        Set the appropriate response format parameter based on model rules.
-
-        :param model_parameters: Model parameters to update
-        :param rules: Model parameter rules
-        """
-        for rule in rules:
-            if rule.name == "response_format":
-                if ResponseFormat.JSON.value in rule.options:
-                    model_parameters["response_format"] = ResponseFormat.JSON.value
-                elif ResponseFormat.JSON_OBJECT.value in rule.options:
-                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
-
-    def _prepare_schema_for_model(self, schema: dict) -> dict:
-        """
-        Prepare JSON schema based on model requirements.
-
-        Different models have different requirements for JSON schema formatting.
-        This function handles these differences.
-
-        :param schema: The original JSON schema
-        :return: Processed schema compatible with the current model
-        """
-
-        def remove_additional_properties(schema: dict) -> None:
-            """
-            Remove additionalProperties fields from JSON schema.
-            Used for models like Gemini that don't support this property.
-
-            :param schema: JSON schema to modify in-place
-            """
-            if not isinstance(schema, dict):
-                return
-
-            # Remove additionalProperties at current level
-            schema.pop("additionalProperties", None)
-
-            # Process nested structures recursively
-            for value in schema.values():
-                if isinstance(value, dict):
-                    remove_additional_properties(value)
-                elif isinstance(value, list):
-                    for item in value:
-                        if isinstance(item, dict):
-                            remove_additional_properties(item)
-
-        def convert_boolean_to_string(schema: dict) -> None:
-            """
-            Convert boolean type specifications to string in JSON schema.
-
-            :param schema: JSON schema to modify in-place
-            """
-            if not isinstance(schema, dict):
-                return
-
-            # Check for boolean type at current level
-            if schema.get("type") == "boolean":
-                schema["type"] = "string"
-
-            # Process nested dictionaries and lists recursively
-            for value in schema.values():
-                if isinstance(value, dict):
-                    convert_boolean_to_string(value)
-                elif isinstance(value, list):
-                    for item in value:
-                        if isinstance(item, dict):
-                            convert_boolean_to_string(item)
-
-        # Deep copy to avoid modifying the original schema
-        processed_schema = schema.copy()
-
-        # Convert boolean types to string types (common requirement)
-        convert_boolean_to_string(processed_schema)
-
-        # Apply model-specific transformations
-        if SpecialModelType.GEMINI in self.model:
-            remove_additional_properties(processed_schema)
-            return processed_schema
-        elif SpecialModelType.OLLAMA in self.provider:
-            return processed_schema
-        else:
-            # Default format with name field
-            return {"schema": processed_schema, "name": "llm_response"}
-
-    def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
-        """
-        Fetch model schema
-        """
-        model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
-        return model_provider.get_model_schema(
-            provider=provider, model_type=model_type, model=model, credentials=self.credentials
-        )
-
 
 class ModelManager:
     def __init__(self) -> None:
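Net effect of the removals above: `ModelInstance.invoke_llm` no longer accepts (and pops) a `structured_output_schema` key from `model_parameters`; the workflow LLM node now owns that concern, as the `node.py` changes below show. A minimal before/after sketch of the call shape, with illustrative values only:

```python
# Before: the schema rode along inside model_parameters and was popped by
# ModelInstance._handle_structured_output() during invoke_llm().
old_params = {
    "temperature": 0.7,
    "structured_output_schema": '{"type": "object", "properties": {"name": {"type": "string"}}}',
}

# After: invoke_llm() receives plain parameters; LLMNode decides beforehand
# whether to add native JSON-schema parameters or rewrite the prompt instead.
new_params = {
    "temperature": 0.7,
    # native path adds e.g. "json_schema": "...", "response_format": "json_schema"
    # fallback path adds e.g. "response_format": "JSON" (exact strings depend on ResponseFormat)
}
```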
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index c924d351f2..3009178c2b 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel):
 
     @model_validator(mode="after")
     def validate_model(self):
-        supported_schema_keys = ["json_schema", "format"]
+        supported_schema_keys = ["json_schema"]
        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
         if schema_key:
             if self.features is None:
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 698ad158cd..5ae68dc2f1 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -29,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -59,6 +65,8 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -132,7 +140,6 @@ class LLMNode(BaseNode[LLMNodeData]):
 
             if isinstance(event, RunRetrieverResourceEvent):
                 context = event.context
                 yield event
-
             if context:
                 node_inputs["#context#"] = context
@@ -510,14 +517,6 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         # model config
         completion_params = node_data_model.completion_params
-        if (
-            isinstance(self.node_data, LLMNodeData)
-            and self.node_data.structured_output_enabled
-            and self.node_data.structured_output
-        ):
-            completion_params["structured_output_schema"] = json.dumps(
-                self.node_data.structured_output.get("schema", {}), ensure_ascii=False
-            )
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
@@ -532,7 +531,12 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} not exist.")
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output is False:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
@@ -743,7 +747,11 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output is False:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop
@@ -945,6 +953,167 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         return prompt_messages
 
+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Handle structured output for models with native JSON schema support.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        :return: Updated model parameters with JSON schema configuration
+        """
+        # Process schema according to model requirements
+        schema = self._fetch_structured_output_schema()
+        schema_json = self._prepare_schema_for_model(schema)
+
+        # Set JSON schema in parameters
+        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        # Set appropriate response format if required by the model
+        for rule in rules:
+            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+        return model_parameters
+
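Roughly what the native path leaves in `completion_params` for a model whose rules include `json_schema`; the `response_format` string is an assumption about `ResponseFormat.JSON_SCHEMA.value`, and the schema is illustrative:

```python
import json

schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}

completion_params = {"temperature": 0.7}
# _prepare_schema_for_model wraps schemas for non-Gemini/non-Ollama models:
completion_params["json_schema"] = json.dumps({"schema": schema, "name": "llm_response"}, ensure_ascii=False)
# Only set when the model's response_format rule offers the JSON_SCHEMA option:
completion_params["response_format"] = "json_schema"  # assumed enum value
print(completion_params)
```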
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
+        This function modifies the prompt messages to include schema-based output requirements.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: Updated prompt messages with structured output requirements
+        """
+        # Convert schema to string format
+        schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+
+        # Find the existing system prompt, if any
+        system_prompt = next(
+            (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+            None,
+        )
+        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+        # Prepare system prompt content
+        system_prompt_content = (
+            structured_output_prompt + "\n\n" + system_prompt.content
+            if system_prompt and isinstance(system_prompt.content, str)
+            else structured_output_prompt
+        )
+        system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+        # Place the combined system prompt ahead of the remaining messages
+        filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+        updated_prompt = [system_prompt] + filtered_prompts
+
+        return updated_prompt
+
+    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
+        """
+        Set the appropriate response format parameter based on model rules.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        """
+        for rule in rules:
+            if rule.name == "response_format":
+                if ResponseFormat.JSON.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON.value
+                elif ResponseFormat.JSON_OBJECT.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
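For models without native support, the fallback rewrites the message list rather than the parameters. A self-contained sketch of that rewrite, using a stand-in message type and a simplified template in place of the runtime's `PromptMessage` classes and `STRUCTURED_OUTPUT_PROMPT`:

```python
from dataclasses import dataclass


@dataclass
class Msg:  # stand-in for SystemPromptMessage / UserPromptMessage
    role: str
    content: str


TEMPLATE = "You must answer in JSON matching this schema:\n{{schema}}"  # simplified stand-in


def prepend_schema(messages: list[Msg], schema_str: str) -> list[Msg]:
    # Mirrors _handle_prompt_based_schema: fold the schema instructions into a
    # single system message placed ahead of all non-system messages.
    instructions = TEMPLATE.replace("{{schema}}", schema_str)
    system = next((m for m in messages if m.role == "system"), None)
    content = instructions + "\n\n" + system.content if system else instructions
    return [Msg("system", content)] + [m for m in messages if m.role != "system"]


msgs = [Msg("system", "Be concise."), Msg("user", "Who is John Doe?")]
print(prepend_schema(msgs, '{"type": "object", "required": ["name"]}')[0].content)
```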
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+
+        # Deep copy via a JSON round-trip so the in-place helpers below cannot
+        # mutate nested structures of the original schema
+        processed_schema = json.loads(json.dumps(schema))
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
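A worked example of the two transformations this method applies; `convert_boolean_to_string` here is an abridged copy of the module-level helper added near the end of this file's diff, and the input schema is illustrative:

```python
import json


def convert_boolean_to_string(schema: dict) -> None:
    # Abridged copy of the helper below, for demonstration only.
    if schema.get("type") == "boolean":
        schema["type"] = "string"
    for value in schema.values():
        if isinstance(value, dict):
            convert_boolean_to_string(value)


schema = {"type": "object", "properties": {"active": {"type": "boolean"}}}
convert_boolean_to_string(schema)
assert schema == {"type": "object", "properties": {"active": {"type": "string"}}}

# Gemini models additionally get additionalProperties stripped and receive the
# schema bare; Ollama providers also receive it bare; everyone else gets:
print(json.dumps({"schema": schema, "name": "llm_response"}))
```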
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Fetch model schema
+        """
+        model_name = self.node_data.model.name
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
+        )
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_schema
+
+    def _fetch_structured_output_schema(self) -> dict:
+        """
+        Fetch the structured output schema for language models.
+
+        This method retrieves and validates the JSON schema defined in the node data
+        for structured output formatting. It ensures the schema is properly formatted
+        and can be used for structured response generation.
+
+        :return: The validated JSON schema as a dictionary
+        """
+        # Extract and validate schema
+        if not self.node_data.structured_output:
+            raise LLMNodeError("Please provide a valid structured output schema")
+        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+        if not structured_output_schema:
+            raise LLMNodeError("Please provide a valid structured output schema")
+
+        try:
+            schema = json.loads(structured_output_schema)
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON format")
+        return schema
+
+    def _check_model_structured_output_support(self) -> Optional[bool]:
+        """
+        Check if the current model supports structured output.
+
+        Returns:
+            Optional[bool]:
+                - True if model supports structured output
+                - False if model exists but doesn't support structured output
+                - None if structured output is disabled or model doesn't exist
+        """
+        # Early return if structured output is disabled
+        if (
+            not isinstance(self.node_data, LLMNodeData)
+            or not self.node_data.structured_output_enabled
+            or not self.node_data.structured_output
+        ):
+            return None
+        # Get model schema and check if it exists
+        model_schema = self._fetch_model_schema(self.node_data.model.provider)
+        if not model_schema:
+            return None
+
+        # Check if model supports structured output feature
+        return bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
+
 
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
@@ -1083,3 +1252,49 @@ def _handle_completion_template(
     )
     prompt_messages.append(prompt_message)
     return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Remove additionalProperties fields from JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Remove additionalProperties at current level
+    schema.pop("additionalProperties", None)
+
+    # Process nested structures recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            remove_additional_properties(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Convert boolean type specifications to string in JSON schema.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Check for boolean type at current level
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    # Process nested dictionaries and lists recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            convert_boolean_to_string(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    convert_boolean_to_string(item)
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
index bceadcbfab..2f0bd3d016 100644
--- a/api/core/workflow/utils/structured_output/prompt.py
+++ b/api/core/workflow/utils/structured_output/prompt.py
@@ -1,10 +1,9 @@
-STRUCTURED_OUTPUT_PROMPT = """
-You’re a helpful AI assistant. You could answer questions and output in JSON format.
+STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
 constrant:
 - You must output in JSON format.
 - Do not output boolean value, use string type instead.
 - Do not output integer or float value, use number type instead.
-eg1:
+eg:
 Here is the JSON schema:
 {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
 
@@ -13,12 +12,6 @@ eg1:
 output:
 {"name": "John Doe", "age": 30}
 
-Here is the JSON schema:
-{{schema}}
-
-Here is the user's question:
-{{question}}
-output:
-
 """  # noqa: E501
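With `{{question}}` removed from the template, the prompt-based path now substitutes only the schema and prepends the result to the system message, instead of replacing the last user message as the deleted `model_manager.py` code did. A quick illustration with an abridged stand-in template:

```python
import json

# Abridged stand-in for STRUCTURED_OUTPUT_PROMPT after this change.
TEMPLATE = "You must output in JSON format.\nHere is the JSON schema:\n{{schema}}\n"

schema_str = json.dumps({"type": "object", "required": ["name"]})
system_prefix = TEMPLATE.replace("{{schema}}", schema_str)
print(system_prefix)  # _handle_prompt_based_schema prepends this to the system prompt
```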