fix: mypy issues

This commit is contained in:
Yeuoly 2025-01-09 16:53:30 +08:00
parent 76e24d91c0
commit f748d6c7c4
49 changed files with 157 additions and 133 deletions

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[dict, Generator[str | dict, None, None]]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: dict,
args: Mapping[str, Any],
streaming: bool = True,
) -> dict[str, Any] | Generator[str | dict, Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False):
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
"""
Get custom credentials.
@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types = []
model_types: list[ModelType] = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_schema.supported_model_types
model_types = list(provider_schema.supported_model_types)
# Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())
def get(self, key, default=None):
def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations.get(key, default)
return self.configurations.get(key, default) # type: ignore
class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64

View File

@ -48,7 +48,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = cast(str, response.message.content)
@ -101,7 +101,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
@ -110,7 +110,7 @@ class LLMGenerator:
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
except Exception:
logging.exception("Failed to generate suggested questions after answer")
questions = []
@ -150,7 +150,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
@ -200,7 +200,7 @@ class LLMGenerator:
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
@ -236,7 +236,7 @@ class LLMGenerator:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@ -248,7 +248,7 @@ class LLMGenerator:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
@ -301,7 +301,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)

View File

@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
plugin_model_manager = PluginModelManager()
result = plugin_model_manager.invoke_llm(

View File

@ -285,17 +285,17 @@ class ModelProviderFactory:
}
if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params)
return LargeLanguageModel(**init_params) # type: ignore
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params)
return TextEmbeddingModel(**init_params) # type: ignore
elif model_type == ModelType.RERANK:
return RerankModel(**init_params)
return RerankModel(**init_params) # type: ignore
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params)
return Speech2TextModel(**init_params) # type: ignore
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params)
return ModerationModel(**init_params) # type: ignore
elif model_type == ModelType.TTS:
return TTSModel(**init_params)
return TTSModel(**init_params) # type: ignore
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

View File

@ -119,7 +119,7 @@ class BasePluginManager:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json())
return type(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
@ -140,7 +140,7 @@ class BasePluginManager:
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
@ -171,7 +171,7 @@ class BasePluginManager:
line_data = None
try:
line_data = json.loads(line)
rep = PluginDaemonBasicResponse[type](**line_data)
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
except Exception:
# TODO modify this when line_data has code and message
if line_data and "error" in line_data:

View File

@ -742,7 +742,7 @@ class ProviderManager:
try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials: dict[str, Any] = {}
provider_credentials = {}
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(

View File

@ -601,6 +601,9 @@ class DatasetRetrieval:
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
if retrieve_config.reranking_model is None:
raise ValueError("Reranking model is required for multiple retrieval")
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,

View File

@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
if not text:
return 0
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []
if embedding_model_instance:
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
else:
return GPT2Tokenizer.get_num_tokens(text)
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
_good_splits_lengths = [] # cache the lengths of the splits
s_lens = self._length_function(splits)
for s, s_len in zip(splits, s_lens):
s_len = self._length_function(s)
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)

View File

@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
separator_len = self._length_function([separator])[0]
docs = []
current_doc: list[str] = []
@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
):
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
total -= self._length_function([current_doc[0]])[0] + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
raise ValueError(
"Could not import transformers python package. Please install it with `pip install transformers`."
)
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
@classmethod
def from_tiktoken_encoder(
@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_tiktoken_encoder, **kwargs)
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""

View File

@ -71,13 +71,13 @@ class Tool(ABC):
if isinstance(result, ToolInvokeMessage):
def single_generator():
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
yield result
return single_generator()
elif isinstance(result, list):
def generator():
def generator() -> Generator[ToolInvokeMessage, None, None]:
yield from result
return generator()

View File

@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None:
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
@property
def need_credentials(self) -> bool:

View File

@ -1,6 +1,7 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass

View File

@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
timezone = None
time_format = "%Y-%m-%d %H:%M:%S"
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
if not timestamp:
yield self.create_text_message(f"Invalid localtime: {localtime}")
return
@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
local_time = datetime.strptime(localtime, time_format)
localtime = local_tz.localize(local_time)
timestamp = int(localtime.timestamp())
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
return timestamp
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
"""
Convert timestamp to localtime
"""
timestamp = tool_parameters.get("timestamp")
timestamp: int = tool_parameters.get("timestamp", 0)
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
if not timezone:
timezone = None

View File

@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
current_time = tool_parameters.get("current_time")
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time:
yield self.create_text_message(
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"

View File

@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass

View File

@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = [
ProviderConfig(
name="auth_type",

View File

@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
):
raise ToolProviderCredentialValidationError("Invalid credentials")
def get_tool(self, tool_name: str) -> PluginTool:
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
"""
return tool with given name
"""
@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_tools(self) -> list[PluginTool]:
def get_tools(self) -> list[PluginTool]: # type: ignore
"""
get all tools
"""

View File

@ -59,7 +59,12 @@ class PluginTool(Tool):
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
"""
@ -76,6 +81,9 @@ class PluginTool(Tool):
provider=self.entity.identity.provider,
tool=self.entity.identity.name,
credentials=self.runtime.credentials,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
return self.runtime_parameters

View File

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL
@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers = {}
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@ -203,7 +203,7 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else:
builtin_provider: BuiltinToolProvider | None = (
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.first()
@ -270,9 +270,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools(
user_id="", tenant_id=workflow_provider.tenant_id
)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
@ -747,18 +745,21 @@ class ToolManager:
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
),
)
@classmethod
@ -795,7 +796,8 @@ class ToolManager:
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon)
icon: dict = json.loads(workflow_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@ -811,7 +813,8 @@ class ToolManager:
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
return json.loads(api_provider.icon)
icon: dict = json.loads(api_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.external_knowledge_service import ExternalDatasetService

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):
return tools
def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""

View File

@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file")
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None

View File

@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> bool:
):
"""
check is synced

View File

@ -6,7 +6,6 @@ from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None)
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
user = db_provider.user
@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
def get_tool(self, tool_name: str) -> Optional[Tool]:
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore
"""
get tool by name

View File

@ -106,9 +106,9 @@ class WorkflowTool(Tool):
if outputs is None:
outputs = {}
else:
outputs, files = self._extract_files(outputs)
outputs, files = self._extract_files(outputs) # type: ignore
for file in files:
yield self.create_file_message(file)
yield self.create_file_message(file) # type: ignore
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
@ -217,7 +217,7 @@ class WorkflowTool(Tool):
:param result: the result
:return: the result, files
"""
files = []
files: list[File] = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
@ -238,4 +238,5 @@ class WorkflowTool(Tool):
files.append(file)
result[key] = value
return result, files

View File

@ -27,7 +27,7 @@ class AgentNode(ToolNode):
Agent Node
"""
_node_data_cls = AgentNodeData
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
def _run(self) -> Generator:
@ -125,7 +125,7 @@ class AgentNode(ToolNode):
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@ -214,7 +214,7 @@ class AgentNode(ToolNode):
:return:
"""
node_data = cast(AgentNodeData, node_data)
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
if input.type == "mixed":

View File

@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
stop=stop,
stop=list(stop or []),
stream=True,
user=self.user_id,
)

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
variables: dict[str, Any] = {}

View File

@ -284,8 +284,6 @@ class WorkflowEntry:
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
)
# run node

View File

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields
from flask_restful import fields # type: ignore
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
return g._login_user
return g._login_user # type: ignore
return None

View File

@ -1,7 +1,7 @@
import enum
import json
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
@ -56,7 +56,7 @@ class Account(UserMixin, Base):
if ta:
tenant.current_role = ta.role
else:
tenant = None
tenant = None # type: ignore
self._current_tenant = tenant

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast
import sqlalchemy as sa
from deprecated import deprecated
@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):
@property
def credentials(self) -> dict:
return json.loads(self.encrypted_credentials)
return cast(dict, json.loads(self.encrypted_credentials))
class ApiToolProvider(Base):
@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
# who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App:
return db.session.query(App).filter(App.id == self.app_id).first()
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)

View File

@ -23,7 +23,7 @@ class AgentService:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
conversation: Conversation = (
conversation: Conversation | None = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,

View File

@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
model_config = ConfigDict(protected_namespaces=())
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""

View File

@ -173,9 +173,8 @@ class PluginMigration:
"""
Extract model tables.
NOTE: rename google to gemini
"""
models = []
models: list[str] = []
table_pairs = [
("providers", "provider_name"),
("provider_models", "provider_name"),

View File

@ -439,7 +439,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
)
)
result = runtime_tool.validate_credentials(credentials, parameters)
result = tool.validate_credentials(credentials, parameters)
except Exception as e:
return {"error": str(e)}

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Optional, Union, cast
from yarl import URL
@ -44,7 +44,7 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
if isinstance(icon, str):
return json.loads(icon)
return cast(dict, json.loads(icon))
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -1,7 +1,7 @@
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from sqlalchemy import or_
@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.model import App
from models.tools import WorkflowToolProvider
@ -187,7 +188,7 @@ class WorkflowToolManageService:
"""
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
tools: Sequence[WorkflowToolProviderController] = []
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
@ -264,7 +265,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
"""
Get a workflow tool.
:db_tool: the database tool
@ -285,8 +286,8 @@ class WorkflowToolManageService:
raise ValueError("Workflow not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {db_tool.id} not found")
return {
@ -325,8 +326,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")
return [

View File

@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id)