refactor(workflow): introduce specific error handling for LLM nodes (#10221)

This commit is contained in:
-LAN- 2024-11-04 15:22:58 +08:00 committed by GitHub
parent 2adab7f71a
commit 38bca6731c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 12 deletions

View File

@ -0,0 +1,26 @@
class LLMNodeError(ValueError):
"""Base class for LLM Node errors."""
class VariableNotFoundError(LLMNodeError):
"""Raised when a required variable is not found."""
class InvalidContextStructureError(LLMNodeError):
"""Raised when the context structure is invalid."""
class InvalidVariableTypeError(LLMNodeError):
"""Raised when the variable type is invalid."""
class ModelNotExistError(LLMNodeError):
"""Raised when the specified model does not exist."""
class LLMModeRequiredError(LLMNodeError):
"""Raised when LLM mode is required but not provided."""
class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration."""

View File

@ -56,6 +56,15 @@ from .entities import (
LLMNodeData, LLMNodeData,
ModelConfig, ModelConfig,
) )
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
ModelNotExistError,
NoPromptFoundError,
VariableNotFoundError,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if self.node_data.memory: if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query: if not query:
raise ValueError("Query not found") raise VariableNotFoundError("Query not found")
query = query.text query = query.text
else: else:
query = None query = None
@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event.usage usage = event.usage
finish_reason = event.finish_reason finish_reason = event.finish_reason
break break
except Exception as e: except LLMNodeError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_name = variable_selector.variable variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None: if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found") raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
def parse_dict(input_dict: Mapping[str, Any]) -> str: def parse_dict(input_dict: Mapping[str, Any]) -> str:
""" """
@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None: if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found") raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment): if isinstance(variable, NoneSegment):
inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = ""
inputs[variable_selector.variable] = variable.to_object() inputs[variable_selector.variable] = variable.to_object()
@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors: for variable_selector in query_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None: if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found") raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment): if isinstance(variable, NoneSegment):
continue continue
inputs[variable_selector.variable] = variable.to_object() inputs[variable_selector.variable] = variable.to_object()
@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return variable.value return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment): elif isinstance(variable, NoneSegment | ArrayAnySegment):
return [] return []
raise ValueError(f"Invalid variable type: {type(variable)}") raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData): def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled: if not node_data.context.enabled:
@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
context_str += item + "\n" context_str += item + "\n"
else: else:
if "content" not in item: if "content" not in item:
raise ValueError(f"Invalid context structure: {item}") raise InvalidContextStructureError(f"Invalid context structure: {item}")
context_str += item["content"] + "\n" context_str += item["content"] + "\n"
@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
if provider_model is None: if provider_model is None:
raise ValueError(f"Model {model_name} not exist.") raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE: if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
# get model mode # get model mode
model_mode = node_data_model.mode model_mode = node_data_model.mode
if not model_mode: if not model_mode:
raise ValueError("LLM mode is required.") raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity( return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name, provider=provider_name,
@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
filtered_prompt_messages.append(prompt_message) filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages: if not filtered_prompt_messages:
raise ValueError( raise NoPromptFoundError(
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()
else: else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {} variable_mapping = {}
for variable_selector in variable_selectors: for variable_selector in variable_selectors: