From 38bca6731c64dfea33fba62a5e71a70504479d8c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:22:58 +0800 Subject: [PATCH] refactor(workflow): introduce specific error handling for LLM nodes (#10221) --- api/core/workflow/nodes/llm/exc.py | 26 +++++++++++++++++++++++ api/core/workflow/nodes/llm/node.py | 33 ++++++++++++++++++----------- 2 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 api/core/workflow/nodes/llm/exc.py diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py new file mode 100644 index 0000000000..f858be2515 --- /dev/null +++ b/api/core/workflow/nodes/llm/exc.py @@ -0,0 +1,26 @@ +class LLMNodeError(ValueError): + """Base class for LLM Node errors.""" + + +class VariableNotFoundError(LLMNodeError): + """Raised when a required variable is not found.""" + + +class InvalidContextStructureError(LLMNodeError): + """Raised when the context structure is invalid.""" + + +class InvalidVariableTypeError(LLMNodeError): + """Raised when the variable type is invalid.""" + + +class ModelNotExistError(LLMNodeError): + """Raised when the specified model does not exist.""" + + +class LLMModeRequiredError(LLMNodeError): + """Raised when LLM mode is required but not provided.""" + + +class NoPromptFoundError(LLMNodeError): + """Raised when no prompt is found in the LLM configuration.""" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index bb9290ddc2..47b0e25d9c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -56,6 +56,15 @@ from .entities import ( LLMNodeData, ModelConfig, ) +from .exc import ( + InvalidContextStructureError, + InvalidVariableTypeError, + LLMModeRequiredError, + LLMNodeError, + ModelNotExistError, + NoPromptFoundError, + VariableNotFoundError, +) if TYPE_CHECKING: from core.file.models import File @@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]): if self.node_data.memory: query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) if not query: - raise ValueError("Query not found") + raise VariableNotFoundError("Query not found") query = query.text else: query = None @@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = event.usage finish_reason = event.finish_reason break - except Exception as e: + except LLMNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") def parse_dict(input_dict: Mapping[str, Any]) -> str: """ @@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = variable.to_object() @@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): continue inputs[variable_selector.variable] = variable.to_object() @@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]): return variable.value elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] - raise ValueError(f"Invalid variable type: {type(variable)}") + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: @@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]): context_str += item + "\n" else: if "content" not in item: - raise ValueError(f"Invalid context structure: {item}") + raise InvalidContextStructureError(f"Invalid context structure: {item}") context_str += item["content"] + "\n" @@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") @@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]): # get model mode model_mode = node_data_model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise LLMModeRequiredError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, @@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]): filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError( + raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) @@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() else: - raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") + raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: