commit b3728bca52
parent 3e31dae1fd

    chore: refactor code in llm node
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
@@ -10,8 +9,8 @@ from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage
-from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
@@ -21,10 +20,7 @@ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
 from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
-from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
-from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
-from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from extensions.ext_redis import redis_client
 from models.provider import ProviderType

@@ -164,13 +160,6 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")

         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        if model_parameters and model_parameters.get("structured_output_schema"):
-            result = self._handle_structured_output(
-                model_parameters=model_parameters,
-                prompt=prompt_messages,
-            )
-            prompt_messages = result["prompt"]
-            model_parameters = result["parameters"]
         return cast(
             Union[LLMResult, Generator],
             self._round_robin_invoke(
@@ -421,190 +410,6 @@ class ModelInstance:
             model=self.model, credentials=self.credentials, language=language
         )

-    def _handle_structured_output(self, model_parameters: dict, prompt: Sequence[PromptMessage]) -> dict:
-        """
-        Handle structured output for language models.
-
-        This function processes schema-based structured outputs by either:
-        1. Using native JSON schema support of the model when available
-        2. Falling back to prompt-based structured output for models without native support
-
-        :param model_parameters: Model configuration parameters
-        :param prompt: Sequence of prompt messages
-        :return: Dictionary with updated prompt and parameters
-        """
-        # Extract and validate schema
-        structured_output_schema = model_parameters.pop("structured_output_schema")
-        if not structured_output_schema:
-            raise ValueError("Please provide a valid structured output schema")
-
-        try:
-            schema = json.loads(structured_output_schema)
-        except json.JSONDecodeError:
-            raise ValueError("structured_output_schema is not valid JSON format")
-
-        # Fetch model schema and validate
-        model_schema = self._fetch_model_schema(self.provider, self.model_type_instance.model_type, self.model)
-        if not model_schema:
-            raise ValueError("Unable to fetch model schema")
-
-        rules = model_schema.parameter_rules
-        rule_names = [rule.name for rule in rules]
-
-        # Handle based on model's JSON schema support capability
-        if "json_schema" in rule_names:
-            return self._handle_native_json_schema(schema, model_parameters, prompt, rules)
-        else:
-            return self._handle_prompt_based_schema(structured_output_schema, model_parameters, prompt, rules)
-
-    def _handle_native_json_schema(
-        self, schema: dict, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
-    ) -> dict:
-        """
-        Handle structured output for models with native JSON schema support.
-
-        :param schema: Parsed JSON schema
-        :param model_parameters: Model parameters to update
-        :param prompt: Sequence of prompt messages
-        :param rules: Model parameter rules
-        :return: Updated prompt and parameters
-        """
-        # Process schema according to model requirements
-        schema_json = self._prepare_schema_for_model(schema)
-
-        # Set JSON schema in parameters
-        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
-
-        # Set appropriate response format if required by the model
-        for rule in rules:
-            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
-                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
-
-        return {"prompt": prompt, "parameters": model_parameters}
-
-    def _handle_prompt_based_schema(
-        self, schema_str: str, model_parameters: dict, prompt: Sequence[PromptMessage], rules: list
-    ) -> dict:
-        """
-        Handle structured output for models without native JSON schema support.
-
-        :param schema_str: JSON schema as string
-        :param model_parameters: Model parameters to update
-        :param prompt: Sequence of prompt messages
-        :param rules: Model parameter rules
-        :return: Updated prompt and parameters
-        """
-        # Extract content from the last prompt message
-        content = prompt[-1].content if isinstance(prompt[-1].content, str) else ""
-
-        # Create structured output prompt
-        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str).replace(
-            "{{question}}", content
-        )
-
-        # Replace the last user message with our structured output prompt
-        structured_prompt_message = UserPromptMessage(content=structured_output_prompt)
-        updated_prompt = list(prompt[:-1]) + [structured_prompt_message]
-
-        # Set appropriate response format based on model capabilities
-        self._set_response_format(model_parameters, rules)
-
-        return {"prompt": updated_prompt, "parameters": model_parameters}
-
-    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
-        """
-        Set the appropriate response format parameter based on model rules.
-
-        :param model_parameters: Model parameters to update
-        :param rules: Model parameter rules
-        """
-        for rule in rules:
-            if rule.name == "response_format":
-                if ResponseFormat.JSON.value in rule.options:
-                    model_parameters["response_format"] = ResponseFormat.JSON.value
-                elif ResponseFormat.JSON_OBJECT.value in rule.options:
-                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
-
-    def _prepare_schema_for_model(self, schema: dict) -> dict:
-        """
-        Prepare JSON schema based on model requirements.
-
-        Different models have different requirements for JSON schema formatting.
-        This function handles these differences.
-
-        :param schema: The original JSON schema
-        :return: Processed schema compatible with the current model
-        """
-
-        def remove_additional_properties(schema: dict) -> None:
-            """
-            Remove additionalProperties fields from JSON schema.
-            Used for models like Gemini that don't support this property.
-
-            :param schema: JSON schema to modify in-place
-            """
-            if not isinstance(schema, dict):
-                return
-
-            # Remove additionalProperties at current level
-            schema.pop("additionalProperties", None)
-
-            # Process nested structures recursively
-            for value in schema.values():
-                if isinstance(value, dict):
-                    remove_additional_properties(value)
-                elif isinstance(value, list):
-                    for item in value:
-                        if isinstance(item, dict):
-                            remove_additional_properties(item)
-
-        def convert_boolean_to_string(schema: dict) -> None:
-            """
-            Convert boolean type specifications to string in JSON schema.
-
-            :param schema: JSON schema to modify in-place
-            """
-            if not isinstance(schema, dict):
-                return
-
-            # Check for boolean type at current level
-            if schema.get("type") == "boolean":
-                schema["type"] = "string"
-
-            # Process nested dictionaries and lists recursively
-            for value in schema.values():
-                if isinstance(value, dict):
-                    convert_boolean_to_string(value)
-                elif isinstance(value, list):
-                    for item in value:
-                        if isinstance(item, dict):
-                            convert_boolean_to_string(item)
-
-        # Deep copy to avoid modifying the original schema
-        processed_schema = schema.copy()
-
-        # Convert boolean types to string types (common requirement)
-        convert_boolean_to_string(processed_schema)
-
-        # Apply model-specific transformations
-        if SpecialModelType.GEMINI in self.model:
-            remove_additional_properties(processed_schema)
-            return processed_schema
-        elif SpecialModelType.OLLAMA in self.provider:
-            return processed_schema
-        else:
-            # Default format with name field
-            return {"schema": processed_schema, "name": "llm_response"}
-
-    def _fetch_model_schema(self, provider: str, model_type: ModelType, model: str) -> AIModelEntity | None:
-        """
-        Fetch model schema
-        """
-        model_provider = ModelProviderFactory(self.model_type_instance.tenant_id)
-        return model_provider.get_model_schema(
-            provider=provider, model_type=model_type, model=model, credentials=self.credentials
-        )
-

 class ModelManager:
     def __init__(self) -> None:
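The block removed above was the generic layer's entire structured-output path: ModelInstance.invoke_llm popped structured_output_schema out of the model parameters, validated it as JSON, fetched the model schema, and then chose native JSON-schema handling or the prompt-based fallback by scanning the model's parameter rules. A minimal, runnable sketch of that dispatch decision, with a simplified ParameterRule stand-in rather than dify's real entity:

    import json
    from dataclasses import dataclass, field


    @dataclass
    class ParameterRule:  # simplified stand-in for dify's ParameterRule entity
        name: str
        options: list[str] = field(default_factory=list)


    def dispatch_structured_output(model_parameters: dict, rules: list[ParameterRule]) -> str:
        """Mirror the deleted dispatch: native JSON-schema support wins when a
        `json_schema` parameter rule exists, otherwise fall back to prompting."""
        schema_str = model_parameters.pop("structured_output_schema", "")
        if not schema_str:
            raise ValueError("Please provide a valid structured output schema")
        json.loads(schema_str)  # fail early on malformed JSON, as the old code did
        rule_names = [rule.name for rule in rules]
        return "native" if "json_schema" in rule_names else "prompt-based"


    print(dispatch_structured_output(
        {"structured_output_schema": '{"type": "object"}'},
        [ParameterRule(name="json_schema")],
    ))  # -> native

After this commit none of that lives in model_manager.py; the LLM node performs the same decision itself (see the additions further down).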
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
@@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel):

     @model_validator(mode="after")
     def validate_model(self):
-        supported_schema_keys = ["json_schema", "format"]
+        supported_schema_keys = ["json_schema"]
         schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
         if schema_key:
             if self.features is None:
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
@@ -29,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -59,6 +65,8 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import ResponseFormat, SpecialModelType
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -132,7 +140,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             if isinstance(event, RunRetrieverResourceEvent):
                 context = event.context
                 yield event

         if context:
             node_inputs["#context#"] = context
-
@@ -510,14 +517,6 @@ class LLMNode(BaseNode[LLMNodeData]):

         # model config
         completion_params = node_data_model.completion_params
-        if (
-            isinstance(self.node_data, LLMNodeData)
-            and self.node_data.structured_output_enabled
-            and self.node_data.structured_output
-        ):
-            completion_params["structured_output_schema"] = json.dumps(
-                self.node_data.structured_output.get("schema", {}), ensure_ascii=False
-            )
         stop = []
         if "stop" in completion_params:
             stop = completion_params["stop"]
@@ -532,7 +531,12 @@ class LLMNode(BaseNode[LLMNodeData]):

         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} not exist.")
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output is False:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
@@ -743,7 +747,11 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
-
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output is False:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop

@@ -945,6 +953,167 @@ class LLMNode(BaseNode[LLMNodeData]):

         return prompt_messages

+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Handle structured output for models with native JSON schema support.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        :return: Updated model parameters with JSON schema configuration
+        """
+        # Process schema according to model requirements
+        schema = self._fetch_structured_output_schema()
+        schema_json = self._prepare_schema_for_model(schema)
+
+        # Set JSON schema in parameters
+        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        # Set appropriate response format if required by the model
+        for rule in rules:
+            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+        return model_parameters
+
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
+        This function modifies the prompt messages to include schema-based output requirements.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: Updated prompt messages with structured output requirements
+        """
+        # Convert schema to string format
+        schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+
+        # Find existing system prompt with schema placeholder
+        system_prompt = next(
+            (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+            None,
+        )
+        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+        # Prepare system prompt content
+        system_prompt_content = (
+            structured_output_prompt + "\n\n" + system_prompt.content
+            if system_prompt and isinstance(system_prompt.content, str)
+            else structured_output_prompt
+        )
+        system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+        # Extract content from the last user message
+
+        filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+        updated_prompt = [system_prompt] + filtered_prompts
+
+        return updated_prompt
+
+    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
+        """
+        Set the appropriate response format parameter based on model rules.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        """
+        for rule in rules:
+            if rule.name == "response_format":
+                if ResponseFormat.JSON.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON.value
+                elif ResponseFormat.JSON_OBJECT.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+
+        # Deep copy to avoid modifying the original schema
+        processed_schema = schema.copy()
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Fetch model schema
+        """
+        model_name = self.node_data.model.name
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
+        )
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_schema
+
+    def _fetch_structured_output_schema(self) -> dict:
+        """
+        Fetch the structured output schema for language models.
+
+        This method retrieves and validates the JSON schema defined in the node data
+        for structured output formatting. It ensures the schema is properly formatted
+        and can be used for structured response generation.
+
+        :return: The validated JSON schema as a dictionary
+        """
+        # Extract and validate schema
+        if not self.node_data.structured_output:
+            raise LLMNodeError("Please provide a valid structured output schema")
+        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+        if not structured_output_schema:
+            raise LLMNodeError("Please provide a valid structured output schema")
+
+        try:
+            schema = json.loads(structured_output_schema)
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON format")
+        return schema
+
+    def _check_model_structured_output_support(self) -> Optional[bool]:
+        """
+        Check if the current model supports structured output.
+
+        Returns:
+            Optional[bool]:
+            - True if model supports structured output
+            - False if model exists but doesn't support structured output
+            - None if structured output is disabled or model doesn't exist
+        """
+        # Early return if structured output is disabled
+        if (
+            not isinstance(self.node_data, LLMNodeData)
+            or not self.node_data.structured_output_enabled
+            or not self.node_data.structured_output
+        ):
+            return None
+        # Get model schema and check if it exists
+        model_schema = self._fetch_model_schema(self.node_data.model.provider)
+        if not model_schema:
+            return None
+
+        # Check if model supports structured output feature
+        return bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
+

 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
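_check_model_structured_output_support is deliberately tri-state, and the two call sites shown earlier split on it: _fetch_model_config applies _handle_native_json_schema only when it returns True, while fetch_prompt_messages rewrites the system prompt through _handle_prompt_based_schema only when it is explicitly False; None (feature disabled, or model schema unavailable) leaves both paths untouched. A runnable sketch of that dispatch with stubbed handlers (the string values below are placeholders, not dify's real parameter values):

    from typing import Optional


    def apply_structured_output(support: Optional[bool], completion_params: dict, prompts: list[str]) -> None:
        """Stub of the two call sites: True -> native json_schema parameters,
        False -> prompt-based fallback, None -> structured output not applied."""
        if support:  # truthy branch, as in _fetch_model_config
            completion_params["response_format"] = "json_schema"  # via _handle_native_json_schema
        elif support is False:  # explicit-False branch, as in fetch_prompt_messages
            prompts.insert(0, "system prompt with {{schema}} filled in")  # via _handle_prompt_based_schema


    for support in (True, False, None):
        params, prompts = {}, ["user question"]
        apply_structured_output(support, params, prompts)
        print(support, params, prompts)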
@@ -1083,3 +1252,49 @@ def _handle_completion_template(
         )
         prompt_messages.append(prompt_message)
     return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Remove additionalProperties fields from JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Remove additionalProperties at current level
+    schema.pop("additionalProperties", None)
+
+    # Process nested structures recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            remove_additional_properties(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Convert boolean type specifications to string in JSON schema.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Check for boolean type at current level
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    # Process nested dictionaries and lists recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            convert_boolean_to_string(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    convert_boolean_to_string(item)
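Both helpers mutate a schema dict in place and recurse through nested dicts and lists. Worth noting: _prepare_schema_for_model calls schema.copy(), which is a shallow copy, so despite the "deep copy" comment these recursive helpers still mutate nested objects shared with the caller's schema. The demo below re-declares the two functions from the diff (docstrings trimmed) so it runs standalone:

    import json


    def remove_additional_properties(schema: dict) -> None:
        # Strip `additionalProperties` recursively (as in the diff above)
        if not isinstance(schema, dict):
            return
        schema.pop("additionalProperties", None)
        for value in schema.values():
            if isinstance(value, dict):
                remove_additional_properties(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        remove_additional_properties(item)


    def convert_boolean_to_string(schema: dict) -> None:
        # Rewrite "type": "boolean" to "type": "string" recursively (as in the diff above)
        if not isinstance(schema, dict):
            return
        if schema.get("type") == "boolean":
            schema["type"] = "string"
        for value in schema.values():
            if isinstance(value, dict):
                convert_boolean_to_string(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict):
                        convert_boolean_to_string(item)


    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"active": {"type": "boolean"}},
    }
    convert_boolean_to_string(schema)     # applied for every model
    remove_additional_properties(schema)  # applied only on the Gemini path
    print(json.dumps(schema))
    # {"type": "object", "properties": {"active": {"type": "string"}}}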
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
@@ -1,10 +1,9 @@
-STRUCTURED_OUTPUT_PROMPT = """
-You’re a helpful AI assistant. You could answer questions and output in JSON format.
+STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
 constrant:
 - You must output in JSON format.
 - Do not output boolean value, use string type instead.
 - Do not output integer or float value, use number type instead.
-eg1:
+eg:
 Here is the JSON schema:
 {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}

@@ -13,12 +12,6 @@ eg1:
-
 output:
 {"name": "John Doe", "age": 30}

 Here is the JSON schema:
 {{schema}}
-
-Here is the user's question:
-{{question}}
-output:
-
 """  # noqa: E501
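With {{question}} gone, the template carries only a {{schema}} placeholder; _handle_prompt_based_schema (added above) fills it and prepends the result to the existing system prompt rather than rewriting the last user message. A quick illustration with a made-up schema and an abbreviated stand-in for the template:

    STRUCTURED_OUTPUT_PROMPT = "...\nHere is the JSON schema:\n{{schema}}\n"  # abbreviated stand-in

    schema_str = '{"properties": {"name": {"type": "string"}}, "type": "object"}'
    structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)

    # Prepend to an existing system prompt, as _handle_prompt_based_schema does
    existing_system_prompt = "You are a customer-support assistant."
    print(structured_output_prompt + "\n\n" + existing_system_prompt)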