From 7a98dab6a4fd01152c543a656b0771d58b600a20 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 09:27:51 +0800 Subject: [PATCH] refactor(parameter_extractor): implement custom error classes (#10260) --- .../workflow/nodes/parameter_extractor/exc.py | 50 ++++++++++++++++ .../parameter_extractor_node.py | 57 ++++++++++++------- 2 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 api/core/workflow/nodes/parameter_extractor/exc.py diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py new file mode 100644 index 0000000000..6511aba185 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -0,0 +1,50 @@ +class ParameterExtractorNodeError(ValueError): + """Base error for ParameterExtractorNode.""" + + +class InvalidModelTypeError(ParameterExtractorNodeError): + """Raised when the model is not a Large Language Model.""" + + +class ModelSchemaNotFoundError(ParameterExtractorNodeError): + """Raised when the model schema is not found.""" + + +class InvalidInvokeResultError(ParameterExtractorNodeError): + """Raised when the invoke result is invalid.""" + + +class InvalidTextContentTypeError(ParameterExtractorNodeError): + """Raised when the text content type is invalid.""" + + +class InvalidNumberOfParametersError(ParameterExtractorNodeError): + """Raised when the number of parameters is invalid.""" + + +class RequiredParameterMissingError(ParameterExtractorNodeError): + """Raised when a required parameter is missing.""" + + +class InvalidSelectValueError(ParameterExtractorNodeError): + """Raised when a select value is invalid.""" + + +class InvalidNumberValueError(ParameterExtractorNodeError): + """Raised when a number value is invalid.""" + + +class InvalidBoolValueError(ParameterExtractorNodeError): + """Raised when a bool value is invalid.""" + + +class InvalidStringValueError(ParameterExtractorNodeError): + """Raised when a string value is invalid.""" + + +class InvalidArrayValueError(ParameterExtractorNodeError): + """Raised when an array value is invalid.""" + + +class InvalidModelModeError(ParameterExtractorNodeError): + """Raised when the model mode is invalid.""" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 49546e9356..b64bde8ac5 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -32,6 +32,21 @@ from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus from .entities import ParameterExtractorNodeData +from .exc import ( + InvalidArrayValueError, + InvalidBoolValueError, + InvalidInvokeResultError, + InvalidModelModeError, + InvalidModelTypeError, + InvalidNumberOfParametersError, + InvalidNumberValueError, + InvalidSelectValueError, + InvalidStringValueError, + InvalidTextContentTypeError, + ModelSchemaNotFoundError, + ParameterExtractorNodeError, + RequiredParameterMissingError, +) from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, @@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema( @@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode): credentials=model_config.credentials, ) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") # fetch memory memory = self._fetch_memory( @@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode): process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) process_data["llm_text"] = text - except Exception as e: + except ParameterExtractorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, @@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode): try: result = self._validate_result(data=node_data, result=result or {}) - except Exception as e: + except ParameterExtractorNodeError as e: error = str(e) # transform result into standard format @@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode): # handle invoke result if not isinstance(invoke_result, LLMResult): - raise ValueError(f"Invalid invoke result: {invoke_result}") + raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content if not isinstance(text, str): - raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode): files=files, ) else: - raise ValueError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {model_mode}") def _generate_prompt_engineering_completion_prompt( self, @@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode): Validate result. """ if len(data.parameters) != len(result): - raise ValueError("Invalid number of parameters") + raise InvalidNumberOfParametersError("Invalid number of parameters") for parameter in data.parameters: if parameter.required and parameter.name not in result: - raise ValueError(f"Parameter {parameter.name} is required") + raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): parameters = result.get(parameter.name) if not isinstance(parameters, list): - raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in parameters: if nested_type == "number" and not isinstance(item, int | float): - raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): - raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") if nested_type == "object" and not isinstance(item, dict): - raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: @@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode): user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode): .replace("}γγγ", "") ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _calculate_rest_token( self, @@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)