chore: refactor code in llm node

Novice 2025-04-08 13:46:55 +08:00
parent 3e31dae1fd
commit b3728bca52
4 changed files with 232 additions and 219 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
@@ -10,8 +9,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -21,10 +20,7 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.provider_manager import ProviderManager
from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from extensions.ext_redis import redis_client
from models.provider import ProviderType
@@ -164,13 +160,6 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
if model_parameters and model_parameters.get("structured_output_schema"):
result = self._handle_structured_output(
model_parameters=model_parameters,
prompt=prompt_messages,
)
prompt_messages = result["prompt"]
model_parameters = result["parameters"]
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@@ -421,190 +410,6 @@ class ModelInstance:
model=self.model, credentials=self.credentials, language=language
)
def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
"""
Handle structured output for language models.
This function processes schema-based structured outputs by either:
1. Using native JSON schema support of the model when available
2. Falling back to prompt-based structured output for models without native support
:param model_parameters: Model configuration parameters
:param prompt: Sequence of prompt messages
:return: Dictionary with updated prompt and parameters
"""
# Extract and validate schema
structured_output_schema = model_parameters.pop("structured_output_schema")
if not structured_output_schema:
raise ValueError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
except json.JSONDecodeError:
raise ValueError("structured_output_schema is not valid JSON format")
# Fetch model schema and validate
model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
if not model_schema:
raise ValueError("Unable to fetch model schema")
rules = model_schema.parameter_rules
rule_names = [rule.name for rule in rules]
# Handle based on model's JSON schema support capability
if "json_schema" in rule_names:
return self._handle_native_json_schema(schema, model_parameters, prompt, rules)
else:
return self._handle_prompt_based_schema(structured_output_schema, model_parameters, prompt, rules)
def _handle_native_json_schema(
self, schema: dict, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
) -> dict:
"""
Handle structured output for models with native JSON schema support.
:param schema: Parsed JSON schema
:param model_parameters: Model parameters to update
:param prompt: Sequence of prompt messages
:param rules: Model parameter rules
:return: Updated prompt and parameters
"""
# Process schema according to model requirements
schema_json = self._prepare_schema_for_model(schema)
# Set JSON schema in parameters
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
return {"prompt": prompt, "parameters": model_parameters}
def _handle_prompt_based_schema(
self, schema_str: str, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
) -> dict:
"""
Handle structured output for models without native JSON schema support.
:param schema_str: JSON schema as string
:param model_parameters: Model parameters to update
:param prompt: Sequence of prompt messages
:param rules: Model parameter rules
:return: Updated prompt and parameters
"""
# Extract content from the last prompt message
content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
# Create structured output prompt
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str).replace(
"{{question}}", content
)
# Replace the last user message with our structured output prompt
structured_prompt_message = UserPromptMessage(content=structured_output_prompt)
updated_prompt = list(prompt[:-1]) + [structured_prompt_message]
# Set appropriate response format based on model capabilities
self._set_response_format(model_parameters, rules)
return {"prompt": updated_prompt, "parameters": model_parameters}
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
"""
Set the appropriate response format parameter based on model rules.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
def _prepare_schema_for_model(self, schema: dict) -> dict:
"""
Prepare JSON schema based on model requirements.
Different models have different requirements for JSON schema formatting.
This function handles these differences.
:param schema: The original JSON schema
:return: Processed schema compatible with the current model
"""
def remove_additional_properties(schema: dict) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None:
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)
# NOTE: .copy() is shallow, so the in-place helpers below still mutate nested structures of the original schema
processed_schema = schema.copy()
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in self.model:
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in self.provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
return model_provider.get_model_schema(
provider=provider, model_type=model_type, model=model, credentials=self.credentials
)
class ModelManager:
def __init__(self) -> None:
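The deleted ModelInstance path above chose between native and prompt-based structured output by scanning the model's parameter rules. A minimal standalone sketch of that dispatch, with the helper name invented for illustration:

def dispatch_structured_output(rule_names: list[str]) -> str:
    # Mirrors the deleted _handle_structured_output: native json_schema
    # support wins; everything else falls back to prompt-based structuring.
    if "json_schema" in rule_names:
        return "native"
    return "prompt-based"

assert dispatch_structured_output(["temperature", "json_schema"]) == "native"
assert dispatch_structured_output(["temperature"]) == "prompt-based"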

View File

@@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel):
@model_validator(mode="after")
def validate_model(self):
-supported_schema_keys = ["json_schema", "format"]
+supported_schema_keys = ["json_schema"]
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
if schema_key:
if self.features is None:
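The validator now accepts only a json_schema parameter rule as evidence of structured output support; "format" no longer qualifies. A simplified sketch of the detection with a hypothetical rule list (what the truncated lines do with the result is assumed from context):

supported_schema_keys = ["json_schema"]
rule_names = ["temperature", "json_schema"]  # hypothetical model's parameter rules

schema_key = next((name for name in rule_names if name in supported_schema_keys), None)
print(schema_key)  # 'json_schema' -> model is treated as structured-output capable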

View File

@@ -29,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
@@ -59,6 +65,8 @@ from core.workflow.nodes.event import (
RunRetrieverResourceEvent,
RunStreamChunkEvent,
)
from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@@ -132,7 +140,6 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
if context:
node_inputs["#context#"] = context
@@ -510,14 +517,6 @@ class LLMNode(BaseNode[LLMNodeData]):
# model config
completion_params = node_data_model.completion_params
if (
isinstance(self.node_data, LLMNodeData)
and self.node_data.structured_output_enabled
and self.node_data.structured_output
):
completion_params["structured_output_schema"] = json.dumps(
self.node_data.structured_output.get("schema", {}), ensure_ascii=False
)
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
@@ -532,7 +531,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output is False:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
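Note that _check_model_structured_output_support is tri-state, so the two branches above are not a plain if/else: True takes the native JSON schema path, False falls back to response-format hints, and None (feature disabled or model schema unavailable) leaves the parameters untouched. A minimal sketch of the dispatch, with placeholder values:

from typing import Optional

def apply_structured_output(params: dict, support: Optional[bool]) -> dict:
    if support:                 # native JSON schema support
        params["json_schema"] = '{"type": "object"}'
    elif support is False:      # model known, but no native support
        params["response_format"] = "json_object"
    # None: feature disabled or model schema missing -> leave params untouched
    return params

print(apply_structured_output({}, True))   # {'json_schema': '{"type": "object"}'}
print(apply_structured_output({}, False))  # {'response_format': 'json_object'}
print(apply_structured_output({}, None))   # {}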
@@ -743,7 +747,11 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
support_structured_output = self._check_model_structured_output_support()
if support_structured_output is False:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop
return filtered_prompt_messages, stop
@@ -945,6 +953,167 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
:return: Updated model parameters with JSON schema configuration
"""
# Process schema according to model requirements
schema = self._fetch_structured_output_schema()
schema_json = self._prepare_schema_for_model(schema)
# Set JSON schema in parameters
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
return model_parameters
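Assuming a model whose response_format rule offers json_schema, the method leaves the completion parameters looking roughly like this (values illustrative):

import json

schema_json = {"schema": {"type": "object"}, "name": "llm_response"}
completion_params = {"temperature": 0.7}
completion_params["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
completion_params["response_format"] = "json_schema"  # only when the rule's options include it
print(completion_params)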
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
This function modifies the prompt messages to include schema-based output requirements.
Args:
prompt_messages: Original sequence of prompt messages
Returns:
list[PromptMessage]: Updated prompt messages with structured output requirements
"""
# Convert schema to string format
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
# Find existing system prompt with schema placeholder
system_prompt = next(
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
None,
)
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
# Prepare system prompt content
system_prompt_content = (
structured_output_prompt + "\n\n" + system_prompt.content
if system_prompt and isinstance(system_prompt.content, str)
else structured_output_prompt
)
system_prompt = SystemPromptMessage(content=system_prompt_content)
# Keep the non-system messages and place them after the structured output system prompt
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
updated_prompt = [system_prompt] + filtered_prompts
return updated_prompt
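A worked example of the message reshaping, using (role, content) tuples in place of the repo's message classes:

messages = [("system", "You are terse."), ("user", "Give me a user record.")]
schema_instruction = 'Output JSON matching: {"required": ["name"]}'

system = next((m for m in messages if m[0] == "system"), None)
content = schema_instruction + "\n\n" + system[1] if system else schema_instruction
updated = [("system", content)] + [m for m in messages if m[0] != "system"]
print(updated[0])  # ('system', 'Output JSON matching: ...\n\nYou are terse.')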
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
"""
Set the appropriate response format parameter based on model rules.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
def _prepare_schema_for_model(self, schema: dict) -> dict:
"""
Prepare JSON schema based on model requirements.
Different models have different requirements for JSON schema formatting.
This function handles these differences.
:param schema: The original JSON schema
:return: Processed schema compatible with the current model
"""
# NOTE: .copy() is shallow, so the in-place helpers still mutate nested structures of the original schema
processed_schema = schema.copy()
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in self.node_data.model.name:
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in self.node_data.model.provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self.node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
def _fetch_structured_output_schema(self) -> dict:
"""
Fetch the structured output schema for language models.
This method retrieves and validates the JSON schema defined in the node data
for structured output formatting. It ensures the schema is properly formatted
and can be used for structured response generation.
:return: The validated JSON schema as a dictionary
"""
# Extract and validate schema
if not self.node_data.structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
return schema
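One observation, shown as a simplified equivalent (LLMNodeError swapped for ValueError so it runs standalone): json.dumps never returns an empty string, so the method's second emptiness check is unreachable, and the dumps/loads pair is a round-trip over an already-parsed dict:

import json

def fetch_schema(structured_output: dict | None) -> dict:
    if not structured_output:
        raise ValueError("Please provide a valid structured output schema")
    # json.dumps always yields a non-empty string (even for {}), so no
    # second emptiness check is needed after serialization.
    return json.loads(json.dumps(structured_output.get("schema", {}), ensure_ascii=False))

print(fetch_schema({"schema": {"type": "object"}}))  # {'type': 'object'}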
def _check_model_structured_output_support(self) -> Optional[bool]:
"""
Check if the current model supports structured output.
Returns:
Optional[bool]:
- True if model supports structured output
- False if model exists but doesn't support structured output
- None if structured output is disabled or model doesn't exist
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return None
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return None
# Check if model supports structured output feature
return bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@@ -1083,3 +1252,49 @@ def _handle_completion_template(
)
prompt_messages.append(prompt_message)
return prompt_messages
def remove_additional_properties(schema: dict) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None:
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)
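With the two helpers above in scope, a quick demonstration of their in-place effect:

schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"active": {"type": "boolean"}},
}
convert_boolean_to_string(schema)
remove_additional_properties(schema)
print(schema)  # {'type': 'object', 'properties': {'active': {'type': 'string'}}}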

View File

@@ -1,10 +1,9 @@
-STRUCTURED_OUTPUT_PROMPT = """
-You're a helpful AI assistant. You can answer questions and output in JSON format.
+STRUCTURED_OUTPUT_PROMPT = """You're a helpful AI assistant. You can answer questions and output in JSON format.
constraint:
- You must output in JSON format.
- Do not output boolean values; use the string type instead.
- Do not output integer or float values; use the number type instead.
-eg1:
+eg:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
@@ -13,12 +12,6 @@ eg1:
output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{schema}}
Here is the user's question:
{{question}}
output:
""" # noqa: E501