From e01653cfee3d4cf3798fcab9a8c2d823cfff5a7f Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 1 Apr 2025 09:13:57 +0800 Subject: [PATCH] fix: handle boolean type value --- api/core/model_manager.py | 193 ++++++++++++++---- .../utils/structured_output/entities.py | 16 ++ .../utils/structured_output/prompt.py | 5 +- 3 files changed, 174 insertions(+), 40 deletions(-) create mode 100644 api/core/workflow/utils/structured_output/entities.py diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 49ac71acd2..18025df39e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager +from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from extensions.ext_redis import redis_client from models.provider import ProviderType @@ -422,12 +423,17 @@ class ModelInstance: def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict: """ - Handle structured output + Handle structured output for language models. - :param model_parameters: model parameters - :param provider: provider name - :return: updated model parameters + This function processes schema-based structured outputs by either: + 1. Using native JSON schema support of the model when available + 2. Falling back to prompt-based structured output for models without native support + + :param model_parameters: Model configuration parameters + :param prompt: Sequence of prompt messages + :return: Dictionary with updated prompt and parameters """ + # Extract and validate schema structured_output_schema = model_parameters.pop("structured_output_schema") if not structured_output_schema: raise ValueError("Please provide a valid structured output schema") @@ -437,50 +443,159 @@ class ModelInstance: except json.JSONDecodeError: raise ValueError("structured_output_schema is not valid JSON format") + # Fetch model schema and validate model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model) if not model_schema: raise ValueError("Unable to fetch model schema") + rules = model_schema.parameter_rules - if "json_schema" in [rule.name for rule in rules]: - name = {"name": "llm_response"} - if "gemini" in self.model: + rule_names = [rule.name for rule in rules] - def remove_additional_properties(schema): - if isinstance(schema, dict): - for key, value in list(schema.items()): - if key == "additionalProperties": - del schema[key] - else: - remove_additional_properties(value) - - remove_additional_properties(schema) - schema_json = schema - elif "ollama" in self.provider: - schema_json = schema - else: - schema_json = {"schema": schema, **name} - - model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) - for rule in rules: - if rule.name == "response_format" and "json_schema" in rule.options: - model_parameters["response_format"] = "json_schema" + # Handle based on model's JSON schema support capability + if "json_schema" in rule_names: + return self._handle_native_json_schema(schema, model_parameters, prompt, rules) else: - content = prompt[-1].content if isinstance(prompt[-1].content, str) else "" - structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", structured_output_schema).replace( - "{{question}}", content - ) - structured_output_prompt_message = UserPromptMessage(content=structured_output_prompt) - prompt = list(prompt[:-1]) + [structured_output_prompt_message] - for rule in rules: - if rule.name == "response_format": - if "JSON" in rule.options: - model_parameters["response_format"] = "JSON" - elif "json_object" in rule.options: - model_parameters["response_format"] = "json_object" - return {"prompt": prompt, "parameters": model_parameters} + return self._handle_prompt_based_schema(structured_output_schema, model_parameters, prompt, rules) + + def _handle_native_json_schema( + self, schema: dict, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list + ) -> dict: + """ + Handle structured output for models with native JSON schema support. + + :param schema: Parsed JSON schema + :param model_parameters: Model parameters to update + :param prompt: Sequence of prompt messages + :param rules: Model parameter rules + :return: Updated prompt and parameters + """ + # Process schema according to model requirements + schema_json = self._prepare_schema_for_model(schema) + + # Set JSON schema in parameters + model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) + + # Set appropriate response format if required by the model + for rule in rules: + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value return {"prompt": prompt, "parameters": model_parameters} + def _handle_prompt_based_schema( + self, schema_str: str, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list + ) -> dict: + """ + Handle structured output for models without native JSON schema support. + + :param schema_str: JSON schema as string + :param model_parameters: Model parameters to update + :param prompt: Sequence of prompt messages + :param rules: Model parameter rules + :return: Updated prompt and parameters + """ + # Extract content from the last prompt message + content = prompt[-1].content if isinstance(prompt[-1].content, str) else "" + + # Create structured output prompt + structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str).replace( + "{{question}}", content + ) + + # Replace the last user message with our structured output prompt + structured_prompt_message = UserPromptMessage(content=structured_output_prompt) + updated_prompt = list(prompt[:-1]) + [structured_prompt_message] + + # Set appropriate response format based on model capabilities + self._set_response_format(model_parameters, rules) + + return {"prompt": updated_prompt, "parameters": model_parameters} + + def _set_response_format(self, model_parameters: dict, rules: list) -> None: + """ + Set the appropriate response format parameter based on model rules. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + """ + for rule in rules: + if rule.name == "response_format": + if ResponseFormat.JSON.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON.value + elif ResponseFormat.JSON_OBJECT.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + + def _prepare_schema_for_model(self, schema: dict) -> dict: + """ + Prepare JSON schema based on model requirements. + + Different models have different requirements for JSON schema formatting. + This function handles these differences. + + :param schema: The original JSON schema + :return: Processed schema compatible with the current model + """ + + def remove_additional_properties(schema: dict) -> None: + """ + Remove additionalProperties fields from JSON schema. + Used for models like Gemini that don't support this property. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Remove additionalProperties at current level + schema.pop("additionalProperties", None) + + # Process nested structures recursively + for value in schema.values(): + if isinstance(value, dict): + remove_additional_properties(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) + + def convert_boolean_to_string(schema: dict) -> None: + """ + Convert boolean type specifications to string in JSON schema. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Check for boolean type at current level + if schema.get("type") == "boolean": + schema["type"] = "string" + + # Process nested dictionaries and lists recursively + for value in schema.values(): + if isinstance(value, dict): + convert_boolean_to_string(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) + + # Deep copy to avoid modifying the original schema + processed_schema = schema.copy() + + # Convert boolean types to string types (common requirement) + convert_boolean_to_string(processed_schema) + + # Apply model-specific transformations + if SpecialModelType.GEMINI in self.model: + remove_additional_properties(processed_schema) + return processed_schema + elif SpecialModelType.OLLAMA in self.provider: + return processed_schema + else: + # Default format with name field + return {"schema": processed_schema, "name": "llm_response"} + def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None: """ Fetch model schema diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py new file mode 100644 index 0000000000..ecdaca4ed2 --- /dev/null +++ b/api/core/workflow/utils/structured_output/entities.py @@ -0,0 +1,16 @@ +from enum import StrEnum + + +class ResponseFormat(StrEnum): + """Constants for model response formats""" + + JSON_SCHEMA = "json_schema" + JSON = "JSON" + JSON_OBJECT = "json_object" + + +class SpecialModelType(StrEnum): + """Constants for identifying model types""" + + GEMINI = "gemini" + OLLAMA = "ollama" diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py index 4fee7bb34d..bceadcbfab 100644 --- a/api/core/workflow/utils/structured_output/prompt.py +++ b/api/core/workflow/utils/structured_output/prompt.py @@ -1,6 +1,9 @@ STRUCTURED_OUTPUT_PROMPT = """ You’re a helpful AI assistant. You could answer questions and output in JSON format. - +constrant: + - You must output in JSON format. + - Do not output boolean value, use string type instead. + - Do not output integer or float value, use number type instead. eg1: Here is the JSON schema: {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}