Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-14 23:25:54 +08:00)

feat: structured output (#17877)

This commit is contained in:
parent d2e3744ca3
commit da9269ca97
@@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource):
         return code_result


+class RuleStructuredOutputGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        try:
+            structured_output = LLMGenerator.generate_structured_output(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
+
+        return structured_output
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
 api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
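
For orientation, a minimal sketch of how a console client might exercise the new route. The request body fields come straight from the parser above; the base URL, the /console/api prefix, and the auth header are assumptions, not part of this diff.

    # Hypothetical caller for the new endpoint; host, prefix and token are assumed.
    import requests

    resp = requests.post(
        "http://localhost:5001/console/api/rule-structured-output-generate",  # assumed prefix
        headers={"Authorization": "Bearer <console-session-token>"},  # assumed auth scheme
        json={
            "instruction": "I need name and age",
            "model_config": {
                "provider": "openai",     # read via model_config.get("provider", "")
                "name": "gpt-4o",         # read via model_config.get("name", "")
                "model_parameters": {},   # read via model_config.get("model_parameters", {})
            },
        },
    )
    # The endpoint relays LLMGenerator's result: {"output": "<json schema>", "error": ""}
    print(resp.json())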
@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
     GENERATOR_QA_PROMPT,
     JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+    SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
@@ -340,3 +341,37 @@ class LLMGenerator:

         answer = cast(str, response.message.content)
         return answer.strip()
+
+    @classmethod
+    def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+
+        prompt_messages = [
+            SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+            UserPromptMessage(content=instruction),
+        ]
+        model_parameters = model_config.get("model_parameters", {})
+
+        try:
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                ),
+            )
+
+            generated_json_schema = cast(str, response.message.content)
+            return {"output": generated_json_schema, "error": ""}
+
+        except InvokeError as e:
+            error = str(e)
+            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+        except Exception as e:
+            logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
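
Called directly, the new class method reports provider failures through its return value rather than raising; a minimal sketch (tenant id and model names are placeholders):

    # Hypothetical direct call; values are illustrative only.
    result = LLMGenerator.generate_structured_output(
        tenant_id="tenant-123",
        instruction="I want to store books with title, author and optional page count",
        model_config={"provider": "openai", "name": "gpt-4o", "model_parameters": {}},
    )
    if result["error"]:
        print("generation failed:", result["error"])
    else:
        print("raw JSON Schema string:", result["output"])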
@@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}}

 You just need to generate the output
 """  # noqa: E501
+
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "name": { "type": "string" },
+        "age": { "type": "number" }
+    },
+    "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "title": { "type": "string" },
+        "author": { "type": "string" },
+        "publicationYear": { "type": "integer" },
+        "pageCount": { "type": "integer" }
+    },
+    "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "email": {
+            "type": "string",
+            "format": "email"
+        },
+        "password": {
+            "type": "string",
+            "minLength": 8
+        },
+        "age": {
+            "type": "integer",
+            "minimum": 18
+        }
+    },
+    "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need an album schema; the album has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+    "type": "object",
+    "properties": {
+        "songs": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": { "type": "string" },
+                    "id": { "type": "string" },
+                    "duration": { "type": "string" },
+                    "artist": { "type": "string" }
+                },
+                "required": ["name", "id", "duration", "artist"]
+            }
+        }
+    },
+    "required": ["songs"]
+}
+
+Now, generate a JSON Schema based on my description
+"""  # noqa: E501
@@ -2,7 +2,7 @@ from decimal import Decimal
 from enum import Enum, StrEnum
 from typing import Any, Optional

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator

 from core.model_runtime.entities.common_entities import I18nObject

@@ -85,6 +85,7 @@ class ModelFeature(Enum):
     DOCUMENT = "document"
     VIDEO = "video"
     AUDIO = "audio"
+    STRUCTURED_OUTPUT = "structured-output"


 class DefaultParameterName(StrEnum):
@@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
     parameter_rules: list[ParameterRule] = []
     pricing: Optional[PriceConfig] = None

+    @model_validator(mode="after")
+    def validate_model(self):
+        supported_schema_keys = ["json_schema"]
+        schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
+        if not schema_key:
+            return self
+        if self.features is None:
+            self.features = [ModelFeature.STRUCTURED_OUTPUT]
+        else:
+            if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+                self.features.append(ModelFeature.STRUCTURED_OUTPUT)
+        return self
+
+
 class ModelUsage(BaseModel):
     pass
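
In effect, the validator above makes the structured-output capability self-declaring: any model whose parameter rules expose a "json_schema" parameter is tagged with ModelFeature.STRUCTURED_OUTPUT exactly once. A standalone mimic of that logic (plain strings, not the real pydantic entities):

    rule_names = ["temperature", "json_schema"]   # hypothetical parameter rules
    features = ["vision"]                         # features declared by the provider
    if "json_schema" in rule_names and "structured-output" not in features:
        features.append("structured-output")
    assert features == ["vision", "structured-output"]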
@@ -16,7 +16,7 @@ from core.variables.segments import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
+from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
 from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -251,7 +251,12 @@ class AgentNode(ToolNode):
                     prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
                 ]
                 value["history_prompt_messages"] = history_prompt_messages
-                value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
+                if model_schema:
+                    # remove structured output feature to support old version agent plugin
+                    model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
+                    value["entity"] = model_schema.model_dump(mode="json")
+                else:
+                    value["entity"] = None
             result[parameter_name] = value

         return result
@@ -348,3 +353,10 @@ class AgentNode(ToolNode):
         )
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
         return model_instance, model_schema
+
+    def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
+        if model_schema.features:
+            # iterate over a copy so in-place removal does not skip elements
+            for feature in list(model_schema.features):
+                if feature.value not in AgentOldVersionModelFeatures:
+                    model_schema.features.remove(feature)
+        return model_schema
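
The effect of this filter, sketched with plain feature values (hypothetical input; "structured-output" is the only value here that old agent plugins do not recognize):

    # Feature values mirror the AgentOldVersionModelFeatures enum below.
    old_sdk_features = {"tool-call", "multi-tool-call", "agent-thought", "vision",
                        "stream-tool-call", "document", "video", "audio"}
    model_features = ["vision", "stream-tool-call", "structured-output"]
    filtered = [f for f in model_features if f in old_sdk_features]
    assert filtered == ["vision", "stream-tool-call"]  # structured-output stripped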
@@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
 class ParamsAutoGenerated(Enum):
     CLOSE = 0
     OPEN = 1
+
+
+class AgentOldVersionModelFeatures(Enum):
+    """
+    Enum of the LLM features known to old SDK versions.
+    """
+
+    TOOL_CALL = "tool-call"
+    MULTI_TOOL_CALL = "multi-tool-call"
+    AGENT_THOUGHT = "agent-thought"
+    VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
+    structured_output: dict | None = None
+    structured_output_enabled: bool = False

     @field_validator("prompt_config", mode="before")
     @classmethod
@@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast

+import json_repair
+
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
@@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
     RunRetrieverResourceEvent,
     RunStreamChunkEvent,
 )
+from core.workflow.utils.structured_output.entities import (
+    ResponseFormat,
+    SpecialModelType,
+    SupportStructuredOutputStatus,
+)
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models.model import Conversation
@@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_type = NodeType.LLM

     def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+            """Process structured output if enabled"""
+            if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+                return None
+            return self._parse_structured_output(text)
+
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             if isinstance(event, RunRetrieverResourceEvent):
                 context = event.context
                 yield event
-
             if context:
                 node_inputs["#context#"] = context

@@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]):
                     self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                 break
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
+        structured_output = process_structured_output(result_text)
+        if structured_output:
+            outputs["structured_output"] = structured_output
         yield RunCompletedEvent(
             run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]):

         if not model_schema:
             raise ModelNotExistError(f"Model {model_name} not exist.")
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
+            completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+        elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            # Set appropriate response format based on model capabilities
+            self._set_response_format(completion_params, model_schema.parameter_rules)
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
             model=model_name,
@@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]):
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
+        support_structured_output = self._check_model_structured_output_support()
+        if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+            filtered_prompt_messages = self._handle_prompt_based_schema(
+                prompt_messages=filtered_prompt_messages,
+            )
         stop = model_config.stop
         return filtered_prompt_messages, stop

+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        structured_output: dict[str, Any] | list[Any] = {}
+        try:
+            parsed = json.loads(result_text)
+            if not isinstance(parsed, dict | list):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        except json.JSONDecodeError:
+            # if the result_text is not valid JSON, try to repair it
+            parsed = json_repair.loads(result_text)
+            if not isinstance(parsed, dict | list):
+                raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+            structured_output = parsed
+        return structured_output
+
     @classmethod
     def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
         provider_model_bundle = model_instance.provider_model_bundle
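
The repair path matters because models without native JSON modes often emit almost-valid JSON. A small illustration of the fallback used above (the strings are illustrative; json_repair.loads mirrors json.loads but tolerates common syntax damage):

    import json
    import json_repair

    broken = '{"name": "John Doe", "age": 30,}'  # trailing comma: json.loads rejects this
    try:
        parsed = json.loads(broken)
    except json.JSONDecodeError:
        parsed = json_repair.loads(broken)
    assert parsed == {"name": "John Doe", "age": 30}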
@@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]):

         return prompt_messages

+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Handle structured output for models with native JSON schema support.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        :return: Updated model parameters with JSON schema configuration
+        """
+        # Process schema according to model requirements
+        schema = self._fetch_structured_output_schema()
+        schema_json = self._prepare_schema_for_model(schema)
+
+        # Set JSON schema in parameters
+        model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+        # Set appropriate response format if required by the model
+        for rule in rules:
+            if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+                model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+        return model_parameters
+
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
+        This function modifies the prompt messages to include schema-based output requirements.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: Updated prompt messages with structured output requirements
+        """
+        # Convert schema to string format
+        schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+
+        # Find existing system prompt with schema placeholder
+        system_prompt = next(
+            (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+            None,
+        )
+        structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+        # Prepare system prompt content
+        system_prompt_content = (
+            structured_output_prompt + "\n\n" + system_prompt.content
+            if system_prompt and isinstance(system_prompt.content, str)
+            else structured_output_prompt
+        )
+        system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+        # Keep all non-system messages and prepend the combined system prompt
+        filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+        updated_prompt = [system_prompt] + filtered_prompts
+
+        return updated_prompt
+
+    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
+        """
+        Set the appropriate response format parameter based on model rules.
+
+        :param model_parameters: Model parameters to update
+        :param rules: Model parameter rules
+        """
+        for rule in rules:
+            if rule.name == "response_format":
+                if ResponseFormat.JSON.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON.value
+                elif ResponseFormat.JSON_OBJECT.value in rule.options:
+                    model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+        # Shallow copy to avoid replacing the original schema object
+        processed_schema = schema.copy()
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Fetch model schema
+        """
+        model_name = self.node_data.model.name
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
+        )
+        model_type_instance = model_instance.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_credentials = model_instance.credentials
+        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+        return model_schema
+
+    def _fetch_structured_output_schema(self) -> dict[str, Any]:
+        """
+        Fetch the structured output schema from the node data.
+
+        Returns:
+            dict[str, Any]: The structured output schema
+        """
+        if not self.node_data.structured_output:
+            raise LLMNodeError("Please provide a valid structured output schema")
+        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+        if not structured_output_schema:
+            raise LLMNodeError("Please provide a valid structured output schema")
+
+        try:
+            schema = json.loads(structured_output_schema)
+            if not isinstance(schema, dict):
+                raise LLMNodeError("structured_output_schema must be a JSON object")
+            return schema
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON format")
+
+    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
+        """
+        Check if the current model supports structured output.
+
+        Returns:
+            SupportStructuredOutputStatus: The support status of structured output
+        """
+        # Early return if structured output is disabled
+        if (
+            not isinstance(self.node_data, LLMNodeData)
+            or not self.node_data.structured_output_enabled
+            or not self.node_data.structured_output
+        ):
+            return SupportStructuredOutputStatus.DISABLED
+        # Get model schema and check if it exists
+        model_schema = self._fetch_model_schema(self.node_data.model.provider)
+        if not model_schema:
+            return SupportStructuredOutputStatus.DISABLED
+
+        # Check if model supports structured output feature
+        return (
+            SupportStructuredOutputStatus.SUPPORTED
+            if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
+            else SupportStructuredOutputStatus.UNSUPPORTED
+        )
+
+
 def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
     match role:
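
How the prompt-based fallback rewrites a conversation, sketched with plain strings (STRUCTURED_OUTPUT_PROMPT comes from the new prompt.py introduced below; the message contents here are illustrative):

    import json

    STRUCTURED_OUTPUT_PROMPT = "...constraints...\nHere is the JSON schema:\n{{schema}}\n"  # abridged
    schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}

    schema_str = json.dumps(schema, ensure_ascii=False)
    structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)

    # Any existing system message is merged after the schema instructions,
    # and the combined system message is placed first in the conversation.
    existing_system = "You are a helpful assistant."
    system_content = structured_output_prompt + "\n\n" + existing_system
    messages = [("system", system_content), ("user", "My name is John Doe.")]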
@@ -1064,3 +1269,49 @@ def _handle_completion_template(
         )
         prompt_messages.append(prompt_message)
     return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Remove additionalProperties fields from JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Remove additionalProperties at current level
+    schema.pop("additionalProperties", None)
+
+    # Process nested structures recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            remove_additional_properties(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Convert boolean type specifications to string in JSON schema.
+
+    :param schema: JSON schema to modify in-place
+    """
+    if not isinstance(schema, dict):
+        return
+
+    # Check for boolean type at current level
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    # Process nested dictionaries and lists recursively
+    for value in schema.values():
+        if isinstance(value, dict):
+            convert_boolean_to_string(value)
+        elif isinstance(value, list):
+            for item in value:
+                if isinstance(item, dict):
+                    convert_boolean_to_string(item)
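
Both helpers mutate a schema in place; a quick before/after on a toy schema (assuming the two functions above are in scope):

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"active": {"type": "boolean"}},
    }
    convert_boolean_to_string(schema)     # booleans become strings for lenient models
    remove_additional_properties(schema)  # Gemini-style cleanup
    assert schema == {"type": "object", "properties": {"active": {"type": "string"}}}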
api/core/workflow/utils/structured_output/entities.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from enum import StrEnum
+
+
+class ResponseFormat(StrEnum):
+    """Constants for model response formats"""
+
+    JSON_SCHEMA = "json_schema"  # Native structured-output mode; supported by models such as Gemini and gpt-4o.
+    JSON = "JSON"  # JSON mode; supported by models such as Claude.
+    JSON_OBJECT = "json_object"  # Another alias for JSON mode; used by models such as deepseek-chat and qwen.
+
+
+class SpecialModelType(StrEnum):
+    """Constants for identifying model types"""
+
+    GEMINI = "gemini"
+    OLLAMA = "ollama"
+
+
+class SupportStructuredOutputStatus(StrEnum):
+    """Constants for structured output support status"""
+
+    SUPPORTED = "supported"
+    UNSUPPORTED = "unsupported"
+    DISABLED = "disabled"
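
Because these are StrEnum members, each constant behaves as its string value, which is what lets node.py test them with plain substring checks against model and provider names; a tiny illustration:

    # StrEnum members are strings, so `in` works as substring containment.
    assert SpecialModelType.GEMINI in "gemini-2.0-flash"
    assert SpecialModelType.OLLAMA not in "gpt-4o"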
api/core/workflow/utils/structured_output/prompt.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+STRUCTURED_OUTPUT_PROMPT = """You're a helpful AI assistant. You can answer questions and output in JSON format.
+constraints:
+- You must output in JSON format.
+- Do not output boolean value, use string type instead.
+- Do not output integer or float value, use number type instead.
+eg:
+Here is the JSON schema:
+{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
+
+Here is the user's question:
+My name is John Doe and I am 30 years old.
+
+output:
+{"name": "John Doe", "age": 30}
+Here is the JSON schema:
+{{schema}}
+"""  # noqa: E501
@@ -30,6 +30,7 @@ dependencies = [
     "gunicorn~=23.0.0",
     "httpx[socks]~=0.27.0",
    "jieba==0.42.1",
+    "json-repair>=0.41.1",
     "langfuse~=2.51.3",
     "langsmith~=0.1.77",
     "mailchimp-transactional~=1.0.50",
@@ -163,10 +164,7 @@ storage = [
 ############################################################
 # [ Tools ] dependency group
 ############################################################
-tools = [
-    "cloudscraper~=1.2.71",
-    "nltk~=3.9.1",
-]
+tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]

 ############################################################
 # [ VDB ] dependency group
api/uv.lock (generated, 14 lines changed)
@@ -1,5 +1,4 @@
 version = 1
-revision = 1
 requires-python = ">=3.11, <3.13"
 resolution-markers = [
     "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'",
@@ -1178,6 +1177,7 @@ dependencies = [
     { name = "gunicorn" },
     { name = "httpx", extra = ["socks"] },
     { name = "jieba" },
+    { name = "json-repair" },
     { name = "langfuse" },
     { name = "langsmith" },
     { name = "mailchimp-transactional" },
@@ -1346,6 +1346,7 @@ requires-dist = [
     { name = "gunicorn", specifier = "~=23.0.0" },
     { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
     { name = "jieba", specifier = "==0.42.1" },
+    { name = "json-repair", specifier = ">=0.41.1" },
     { name = "langfuse", specifier = "~=2.51.3" },
     { name = "langsmith", specifier = "~=0.1.77" },
     { name = "mailchimp-transactional", specifier = "~=1.0.50" },
@@ -2524,6 +2525,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
 ]

+[[package]]
+name = "json-repair"
+version = "0.41.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 },
+]
+
 [[package]]
 name = "jsonpath-python"
 version = "1.0.6"
@@ -4074,6 +4084,8 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
     { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
     { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 },
+    { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 },
+    { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 },
 ]

 [[package]]