fix(core): Reorder field_validator and classmethod to fit Pydantic V2. (#5257)

This commit is contained in:
-LAN- 2024-06-17 10:04:28 +08:00 committed by GitHub
parent e95f8fa3dc
commit 5a99aeb864
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 17 additions and 15 deletions

View File

@ -39,7 +39,7 @@ jobs:
- name: Ruff check - name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff check --preview ./api run: poetry run -C api ruff check ./api
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'

View File

@ -77,8 +77,8 @@ class QueueIterationNextEvent(AppQueueEvent):
node_run_index: int node_run_index: int
output: Optional[Any] = None # output for the current iteration output: Optional[Any] = None # output for the current iteration
@classmethod
@field_validator('output', mode='before') @field_validator('output', mode='before')
@classmethod
def set_output(cls, v): def set_output(cls, v):
""" """
Set output Set output

View File

@ -124,6 +124,7 @@ class AssistantPromptMessage(PromptMessage):
function: ToolCallFunction function: ToolCallFunction
@field_validator('id', mode='before') @field_validator('id', mode='before')
@classmethod
def transform_id_to_str(cls, value) -> str: def transform_id_to_str(cls, value) -> str:
if not isinstance(value, str): if not isinstance(value, str):
return str(value) return str(value)

View File

@ -32,8 +32,8 @@ class TwilioAPIWrapper(BaseModel):
must be empty. must be empty.
""" """
@classmethod
@field_validator('client', mode='before') @field_validator('client', mode='before')
@classmethod
def set_validator(cls, values: dict) -> dict: def set_validator(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
try: try:

View File

@ -66,7 +66,7 @@ class BuiltinTool(Tool):
tenant_id=self.runtime.tenant_id, tenant_id=self.runtime.tenant_id,
prompt_messages=prompt_messages prompt_messages=prompt_messages
) )
def summary(self, user_id: str, content: str) -> str: def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens() max_tokens = self.get_max_tokens()

View File

@ -32,8 +32,8 @@ class Tool(BaseModel, ABC):
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@classmethod
@field_validator('parameters', mode='before') @field_validator('parameters', mode='before')
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or [] return v or []

View File

@ -1,7 +1,7 @@
import os import os
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from pydantic import BaseModel, field_validator from pydantic import BaseModel, ValidationInfo, field_validator
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -24,13 +24,13 @@ class HttpRequestNodeData(BaseNodeData):
type: Literal['no-auth', 'api-key'] type: Literal['no-auth', 'api-key']
config: Optional[Config] config: Optional[Config]
@classmethod
@field_validator('config', mode='before') @field_validator('config', mode='before')
def check_config(cls, v, values): @classmethod
def check_config(cls, v: Config, values: ValidationInfo):
""" """
Check config, if type is no-auth, config should be None, otherwise it should be a dict. Check config, if type is no-auth, config should be None, otherwise it should be a dict.
""" """
if values['type'] == 'no-auth': if values.data['type'] == 'no-auth':
return None return None
else: else:
if not v or not isinstance(v, dict): if not v or not isinstance(v, dict):

View File

@ -25,8 +25,8 @@ class ParameterConfig(BaseModel):
description: str description: str
required: bool required: bool
@classmethod
@field_validator('name', mode='before') @field_validator('name', mode='before')
@classmethod
def validate_name(cls, value) -> str: def validate_name(cls, value) -> str:
if not value: if not value:
raise ValueError('Parameter name is required') raise ValueError('Parameter name is required')
@ -45,8 +45,8 @@ class ParameterExtractorNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
reasoning_mode: Literal['function_call', 'prompt'] reasoning_mode: Literal['function_call', 'prompt']
@classmethod
@field_validator('reasoning_mode', mode='before') @field_validator('reasoning_mode', mode='before')
@classmethod
def set_reasoning_mode(cls, v) -> str: def set_reasoning_mode(cls, v) -> str:
return v or 'function_call' return v or 'function_call'

View File

@ -14,9 +14,9 @@ class ToolEntity(BaseModel):
tool_label: str # redundancy tool_label: str # redundancy
tool_configurations: dict[str, Any] tool_configurations: dict[str, Any]
@classmethod
@field_validator('tool_configurations', mode='before') @field_validator('tool_configurations', mode='before')
def validate_tool_configurations(cls, value, values: ValidationInfo) -> dict[str, Any]: @classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict): if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary') raise ValueError('tool_configurations must be a dictionary')
@ -32,8 +32,8 @@ class ToolNodeData(BaseNodeData, ToolEntity):
value: Union[Any, list[str]] value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant'] type: Literal['mixed', 'variable', 'constant']
@classmethod
@field_validator('type', mode='before') @field_validator('type', mode='before')
@classmethod
def check_type(cls, value, validation_info: ValidationInfo): def check_type(cls, value, validation_info: ValidationInfo):
typ = value typ = value
value = validation_info.data.get('value') value = validation_info.data.get('value')

View File

@ -7,6 +7,7 @@ exclude = [
line-length = 120 line-length = 120
[tool.ruff.lint] [tool.ruff.lint]
preview = true
select = [ select = [
"B", # flake8-bugbear rules "B", # flake8-bugbear rules
"F", # pyflakes rules "F", # pyflakes rules

View File

@ -9,7 +9,7 @@ if ! command -v ruff &> /dev/null; then
fi fi
# run ruff linter # run ruff linter
ruff check --fix --preview ./api ruff check --fix ./api
# env files linting relies on `dotenv-linter` in path # env files linting relies on `dotenv-linter` in path
if ! command -v dotenv-linter &> /dev/null; then if ! command -v dotenv-linter &> /dev/null; then