https://github.com/langgenius/dify
refactor(workflow): introduce specific error handling for LLM nodes (#10221)
parent 2adab7f71a
commit 38bca6731c
api/core/workflow/nodes/llm/exc.py (new file, +26 lines)
@@ -0,0 +1,26 @@
+class LLMNodeError(ValueError):
+    """Base class for LLM Node errors."""
+
+
+class VariableNotFoundError(LLMNodeError):
+    """Raised when a required variable is not found."""
+
+
+class InvalidContextStructureError(LLMNodeError):
+    """Raised when the context structure is invalid."""
+
+
+class InvalidVariableTypeError(LLMNodeError):
+    """Raised when the variable type is invalid."""
+
+
+class ModelNotExistError(LLMNodeError):
+    """Raised when the specified model does not exist."""
+
+
+class LLMModeRequiredError(LLMNodeError):
+    """Raised when LLM mode is required but not provided."""
+
+
+class NoPromptFoundError(LLMNodeError):
+    """Raised when no prompt is found in the LLM configuration."""
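
Every class above inherits from LLMNodeError, which itself subclasses ValueError, so callers that still catch ValueError keep working while new code can match failures precisely. A minimal sketch of both styles (the fetch_query helper and the plain-dict pool are hypothetical stand-ins for the node's variable lookup):

from core.workflow.nodes.llm.exc import LLMNodeError, VariableNotFoundError


def fetch_query(pool: dict) -> str:
    # Hypothetical stand-in for the node's variable-pool lookup.
    query = pool.get("query")
    if query is None:
        raise VariableNotFoundError("Query not found")
    return query


try:
    fetch_query({})
except LLMNodeError as e:
    print(f"node-specific handler: {e}")  # new, precise style

try:
    fetch_query({})
except ValueError as e:
    print(f"legacy handler still works: {e}")  # LLMNodeError is a ValueError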
api/core/workflow/nodes/llm/node.py
@@ -56,6 +56,15 @@ from .entities import (
     LLMNodeData,
     ModelConfig,
 )
+from .exc import (
+    InvalidContextStructureError,
+    InvalidVariableTypeError,
+    LLMModeRequiredError,
+    LLMNodeError,
+    ModelNotExistError,
+    NoPromptFoundError,
+    VariableNotFoundError,
+)
 
 if TYPE_CHECKING:
     from core.file.models import File
@@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         if self.node_data.memory:
             query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
             if not query:
-                raise ValueError("Query not found")
+                raise VariableNotFoundError("Query not found")
             query = query.text
         else:
             query = None
@@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     usage = event.usage
                     finish_reason = event.finish_reason
                     break
-        except Exception as e:
+        except LLMNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
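
Narrowing except Exception to except LLMNodeError changes the failure semantics: only errors the node deliberately raises are converted into a FAILED NodeRunResult, while unexpected bugs (say, a RuntimeError deep in a provider SDK) now propagate to the caller instead of being silently folded into the run result. A rough sketch of the difference, with run_step as a hypothetical stand-in for the generator body:

from core.workflow.nodes.llm.exc import LLMNodeError, VariableNotFoundError


def run_step(error: Exception) -> dict:
    # Hypothetical stand-in for the node's invoke-and-handle loop.
    try:
        raise error
    except LLMNodeError as e:
        # Expected, node-specific failure: reported as a failed run.
        return {"status": "failed", "error": str(e)}


print(run_step(VariableNotFoundError("Query not found")))  # handled

try:
    run_step(RuntimeError("provider SDK bug"))
except RuntimeError as e:
    print(f"escaped: {e}")  # no longer swallowed by the node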
@@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             variable_name = variable_selector.variable
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
 
             def parse_dict(input_dict: Mapping[str, Any]) -> str:
                 """
@@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         for variable_selector in variable_selectors:
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
             if isinstance(variable, NoneSegment):
                 inputs[variable_selector.variable] = ""
             inputs[variable_selector.variable] = variable.to_object()
@@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         for variable_selector in query_variable_selectors:
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
             if isinstance(variable, NoneSegment):
                 continue
             inputs[variable_selector.variable] = variable.to_object()
@@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             return variable.value
         elif isinstance(variable, NoneSegment | ArrayAnySegment):
             return []
-        raise ValueError(f"Invalid variable type: {type(variable)}")
+        raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
 
     def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
@@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 context_str += item + "\n"
             else:
                 if "content" not in item:
-                    raise ValueError(f"Invalid context structure: {item}")
+                    raise InvalidContextStructureError(f"Invalid context structure: {item}")
 
                 context_str += item["content"] + "\n"
 
@@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         )
 
         if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} not exist.")
 
         if provider_model.status == ModelStatus.NO_CONFIGURE:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
@@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
         # get model mode
         model_mode = node_data_model.mode
         if not model_mode:
-            raise ValueError("LLM mode is required.")
+            raise LLMModeRequiredError("LLM mode is required.")
 
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
 
         if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} not exist.")
 
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
@@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             filtered_prompt_messages.append(prompt_message)
 
         if not filtered_prompt_messages:
-            raise ValueError(
+            raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
@@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             variable_template_parser = VariableTemplateParser(template=prompt_template.text)
             variable_selectors = variable_template_parser.extract_variable_selectors()
         else:
-            raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
+            raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
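
Since every call site above only swaps the exception type and keeps its message, a regression test need only pin down the hierarchy so legacy except ValueError handlers stay unaffected. A hypothetical pytest sketch:

import pytest

from core.workflow.nodes.llm.exc import (
    InvalidContextStructureError,
    InvalidVariableTypeError,
    LLMModeRequiredError,
    LLMNodeError,
    ModelNotExistError,
    NoPromptFoundError,
    VariableNotFoundError,
)


@pytest.mark.parametrize(
    "exc_class",
    [
        VariableNotFoundError,
        InvalidContextStructureError,
        InvalidVariableTypeError,
        ModelNotExistError,
        LLMModeRequiredError,
        NoPromptFoundError,
    ],
)
def test_specific_errors_remain_value_errors(exc_class):
    # Each specific error is catchable as LLMNodeError and as ValueError.
    assert issubclass(exc_class, LLMNodeError)
    assert issubclass(exc_class, ValueError)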