From f748d6c7c41d8dc928fe6c0d1090a0cd6b83973d Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Thu, 9 Jan 2025 16:53:30 +0800
Subject: [PATCH] fix: mypy issues

---
 .../agent_chat/generate_response_converter.py |  5 ++-
 .../base_app_generate_response_converter.py   |  2 +-
 .../apps/chat/generate_response_converter.py  |  5 ++-
 .../completion/generate_response_converter.py |  5 ++-
 api/core/app/apps/workflow/app_generator.py   |  6 +--
 .../workflow/generate_response_converter.py   |  5 ++-
 api/core/entities/provider_configuration.py   | 10 ++---
 api/core/file/upload_file_parser.py           |  2 +-
 api/core/llm_generator/llm_generator.py       | 16 +++----
 .../__base/large_language_model.py            |  2 +
 .../model_providers/model_provider_factory.py | 12 +++---
 api/core/plugin/manager/base.py               |  6 +--
 api/core/provider_manager.py                  |  2 +-
 api/core/rag/retrieval/dataset_retrieval.py   |  3 ++
 api/core/rag/splitter/fixed_text_splitter.py  | 11 +++--
 api/core/rag/splitter/text_splitter.py        | 10 +++--
 api/core/tools/__base/tool.py                 |  4 +-
 api/core/tools/builtin_tool/provider.py       |  4 +-
 .../builtin_tool/providers/audio/audio.py     |  3 +-
 .../time/tools/localtime_to_timestamp.py      |  6 +--
 .../time/tools/timestamp_to_localtime.py      |  2 +-
 .../time/tools/timezone_conversion.py         |  2 +-
 .../providers/webscraper/webscraper.py        |  2 +-
 api/core/tools/custom_tool/provider.py        |  2 +-
 api/core/tools/plugin_tool/provider.py        |  4 +-
 api/core/tools/plugin_tool/tool.py            | 10 ++++-
 api/core/tools/tool_manager.py                | 43 ++++++++++---------
 .../dataset_retriever_tool.py                 |  2 +-
 .../tools/utils/dataset_retriever_tool.py     | 18 ++++++--
 api/core/tools/utils/message_transformer.py   |  2 +-
 .../utils/workflow_configuration_sync.py      |  2 +-
 api/core/tools/workflow_as_tool/provider.py   |  5 +--
 api/core/tools/workflow_as_tool/tool.py       |  7 +--
 api/core/workflow/nodes/agent/agent_node.py   |  6 +--
 api/core/workflow/nodes/llm/node.py           |  4 +-
 api/core/workflow/nodes/tool/tool_node.py     |  4 +-
 api/core/workflow/workflow_entry.py           |  2 -
 api/libs/helper.py                            |  2 +-
 api/libs/login.py                             |  2 +-
 api/models/account.py                         |  4 +-
 api/models/model.py                           |  2 +-
 api/models/tools.py                           | 12 +----
 api/services/agent_service.py                 |  2 +-
 .../entities/model_provider_entities.py       |  2 +-
 api/services/plugin/plugin_migration.py       |  3 +-
 .../tools/api_tools_manage_service.py         |  2 +-
 api/services/tools/tools_transform_service.py |  4 +-
 .../tools/workflow_tools_manage_service.py    | 17 ++++----
 .../batch_create_segment_to_index_task.py     |  2 +-
 49 files changed, 157 insertions(+), 133 deletions(-)

diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py
index 82ec33b269..0eea135167 100644
--- a/api/core/app/apps/agent_chat/generate_response_converter.py
+++ b/api/core/app/apps/agent_chat/generate_response_converter.py
@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py
index 49ddb2c83a..29c1ad598e 100644
--- a/api/core/app/apps/base_app_generate_response_converter.py
+++ b/api/core/app/apps/base_app_generate_response_converter.py
@@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
     @abstractmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError
 
     @classmethod
diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py
index bfaefeb8cb..13a6be167c 100644
--- a/api/core/app/apps/chat/generate_response_converter.py
+++ b/api/core/app/apps/chat/generate_response_converter.py
@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py
index 89dda03da1..c2b78e8176 100644
--- a/api/core/app/apps/completion/generate_response_converter.py
+++ b/api/core/app/apps/completion/generate_response_converter.py
@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     CompletionAppBlockingResponse,
     CompletionAppStreamResponse,
     ErrorStreamResponse,
@@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 7c502a143b..f13cb53009 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         workflow_thread_pool_id: Optional[str] = None,
-    ) -> Union[dict, Generator[str | dict, None, None]]:
+    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
 
@@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow: Workflow,
         node_id: str,
         user: Account | EndUser,
-        args: dict,
+        args: Mapping[str, Any],
         streaming: bool = True,
-    ) -> dict[str, Any] | Generator[str | dict, Any, None]:
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
         """
         Generate App response.
 
diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py
index cba7dc96fb..10ec73a7d2 100644
--- a/api/core/app/apps/workflow/generate_response_converter.py
+++ b/api/core/app/apps/workflow/generate_response_converter.py
@@ -3,6 +3,7 @@ from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ErrorStreamResponse,
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
@@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_full_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
@@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
 
     @classmethod
     def convert_stream_simple_response(
-        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
+        cls, stream_response: Generator[AppStreamResponse, None, None]
     ) -> Generator[dict | str, None, None]:
         """
         Convert stream simple response.
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 0636234593..6ed065d925 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
         """
         return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
 
-    def get_custom_credentials(self, obfuscated: bool = False):
+    def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
         """
         Get custom credentials.
 
@@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
         model_provider_factory = ModelProviderFactory(self.tenant_id)
         provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
 
-        model_types = []
+        model_types: list[ModelType] = []
         if model_type:
             model_types.append(model_type)
         else:
-            model_types = provider_schema.supported_model_types
+            model_types = list(provider_schema.supported_model_types)
 
         # Group model settings by model type and model
         model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
     def values(self) -> Iterator[ProviderConfiguration]:
         return iter(self.configurations.values())
 
-    def get(self, key, default=None):
+    def get(self, key, default=None) -> ProviderConfiguration | None:
         if "/" not in key:
             key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
 
-        return self.configurations.get(key, default)
+        return self.configurations.get(key, default)  # type: ignore
 
 
 class ProviderModelBundle(BaseModel):
diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py
index bcbb833b7f..062a0b6d22 100644
--- a/api/core/file/upload_file_parser.py
+++ b/api/core/file/upload_file_parser.py
@@ -20,7 +20,7 @@ class UploadFileParser:
         if upload_file.extension not in IMAGE_EXTENSIONS:
             return None
 
-        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
+        if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
            return cls.get_signed_temp_image_url(upload_file.id)
         else:
             # get image file base64
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 9fe3f68f2a..75687f9ae3 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -48,7 +48,7 @@ class LLMGenerator:
         response = cast(
             LLMResult,
             model_instance.invoke_llm(
-                prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+                prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
             ),
         )
         answer = cast(str, response.message.content)
@@ -101,7 +101,7 @@ class LLMGenerator:
             response = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=prompt_messages,
+                    prompt_messages=list(prompt_messages),
                     model_parameters={"max_tokens": 256, "temperature": 0},
                     stream=False,
                 ),
             )
@@ -110,7 +110,7 @@ class LLMGenerator:
             questions = output_parser.parse(cast(str, response.message.content))
         except InvokeError:
             questions = []
-        except Exception as e:
+        except Exception:
             logging.exception("Failed to generate suggested questions after answer")
             questions = []
 
@@ -150,7 +150,7 @@ class LLMGenerator:
         response = cast(
             LLMResult,
             model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             ),
         )
 
@@ -200,7 +200,7 @@ class LLMGenerator:
             prompt_content = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
                 ),
             )
         except InvokeError as e:
@@ -236,7 +236,7 @@ class LLMGenerator:
             parameter_content = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                    prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
                ),
             )
             rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@@ -248,7 +248,7 @@ class LLMGenerator:
             statement_content = cast(
                 LLMResult,
                 model_instance.invoke_llm(
-                    prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                    prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
                 ),
             )
             rule_config["opening_statement"] = cast(str, statement_content.message.content)
@@ -301,7 +301,7 @@ class LLMGenerator:
         response = cast(
             LLMResult,
             model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
             ),
         )
 
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index c93ab4f61e..e833322f27 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
             callbacks=callbacks,
         )
 
+        result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
+
         try:
             plugin_model_manager = PluginModelManager()
             result = plugin_model_manager.invoke_llm(
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index 2bba3847b1..23596558db 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -285,17 +285,17 @@ class ModelProviderFactory:
         }
 
         if model_type == ModelType.LLM:
-            return LargeLanguageModel(**init_params)
+            return LargeLanguageModel(**init_params)  # type: ignore
         elif model_type == ModelType.TEXT_EMBEDDING:
-            return TextEmbeddingModel(**init_params)
+            return TextEmbeddingModel(**init_params)  # type: ignore
         elif model_type == ModelType.RERANK:
-            return RerankModel(**init_params)
+            return RerankModel(**init_params)  # type: ignore
         elif model_type == ModelType.SPEECH2TEXT:
-            return Speech2TextModel(**init_params)
+            return Speech2TextModel(**init_params)  # type: ignore
         elif model_type == ModelType.MODERATION:
-            return ModerationModel(**init_params)
+            return ModerationModel(**init_params)  # type: ignore
         elif model_type == ModelType.TTS:
-            return TTSModel(**init_params)
+            return TTSModel(**init_params)  # type: ignore
 
     def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
         """
diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py
index 0540cb2d44..ddfc42b974 100644
--- a/api/core/plugin/manager/base.py
+++ b/api/core/plugin/manager/base.py
@@ -119,7 +119,7 @@ class BasePluginManager:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type(**response.json())
+        return type(**response.json())  # type: ignore
 
     def _request_with_plugin_daemon_response(
         self,
@@ -140,7 +140,7 @@ class BasePluginManager:
         if transformer:
             json_response = transformer(json_response)
 
-        rep = PluginDaemonBasicResponse[type](**json_response)
+        rep = PluginDaemonBasicResponse[type](**json_response)  # type: ignore
         if rep.code != 0:
             try:
                 error = PluginDaemonError(**json.loads(rep.message))
@@ -171,7 +171,7 @@ class BasePluginManager:
             line_data = None
             try:
                 line_data = json.loads(line)
-                rep = PluginDaemonBasicResponse[type](**line_data)
+                rep = PluginDaemonBasicResponse[type](**line_data)  # type: ignore
             except Exception:
                 # TODO modify this when line_data has code and message
                 if line_data and "error" in line_data:
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 92ba1133a2..ddc2137f32 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -742,7 +742,7 @@ class ProviderManager:
             try:
                 provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
             except JSONDecodeError:
-                provider_credentials: dict[str, Any] = {}
+                provider_credentials = {}
 
             # Get provider credential secret variables
             provider_credential_secret_variables = self._extract_secret_variables(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 290d9e6e61..2193d8d3fd 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -601,6 +601,9 @@ class DatasetRetrieval:
         elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
             from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 
+            if retrieve_config.reranking_model is None:
+                raise ValueError("Reranking model is required for multiple retrieval")
+
             tool = DatasetMultiRetrieverTool.from_dataset(
                 dataset_ids=[dataset.id for dataset in available_datasets],
                 tenant_id=tenant_id,
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 91fb033c49..2d99cce513 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
         **kwargs: Any,
     ):
-        def _token_encoder(text: str) -> int:
-            if not text:
-                return 0
+        def _token_encoder(texts: list[str]) -> list[int]:
+            if not texts:
+                return []
 
             if embedding_model_instance:
-                return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
+                return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
             else:
-                return GPT2Tokenizer.get_num_tokens(text)
+                return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
 
         if issubclass(cls, TokenTextSplitter):
             extra_kwargs = {
@@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         _good_splits_lengths = []  # cache the lengths of the splits
         s_lens = self._length_function(splits)
         for s, s_len in zip(splits, s_lens):
-            s_len = self._length_function(s)
             if s_len < self._chunk_size:
                 _good_splits.append(s)
                 _good_splits_lengths.append(s_len)
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 72c4700d5c..34b4056cf5 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
-        separator_len = self._length_function(separator)
+        separator_len = self._length_function([separator])[0]
 
         docs = []
         current_doc: list[str] = []
@@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 while total > self._chunk_overlap or (
                     total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
                 ):
-                    total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
+                    total -= self._length_function([current_doc[0]])[0] + (
+                        separator_len if len(current_doc) > 1 else 0
+                    )
                     current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len + (separator_len if len(current_doc) > 1 else 0)
@@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             raise ValueError(
                 "Could not import transformers python package. Please install it with `pip install transformers`."
             )
-        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+        return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
 
     @classmethod
     def from_tiktoken_encoder(
@@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
             }
 
         kwargs = {**kwargs, **extra_kwargs}
-        return cls(length_function=_tiktoken_encoder, **kwargs)
+        return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
 
     def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 255060ef3c..63937f5f76 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -71,13 +71,13 @@ class Tool(ABC):
 
         if isinstance(result, ToolInvokeMessage):
 
-            def single_generator():
+            def single_generator() -> Generator[ToolInvokeMessage, None, None]:
                 yield result
 
             return single_generator()
         elif isinstance(result, list):
 
-            def generator():
+            def generator() -> Generator[ToolInvokeMessage, None, None]:
                 yield from result
 
             return generator()
diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py
index 78949f8b1a..e776258527 100644
--- a/api/core/tools/builtin_tool/provider.py
+++ b/api/core/tools/builtin_tool/provider.py
@@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         return self._get_builtin_tools()
 
-    def get_tool(self, tool_name: str) -> BuiltinTool | None:
+    def get_tool(self, tool_name: str) -> BuiltinTool | None:  # type: ignore
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
 
     @property
     def need_credentials(self) -> bool:
diff --git a/api/core/tools/builtin_tool/providers/audio/audio.py b/api/core/tools/builtin_tool/providers/audio/audio.py
index 116279ad20..e8cfba6138 100644
--- a/api/core/tools/builtin_tool/providers/audio/audio.py
+++ b/api/core/tools/builtin_tool/providers/audio/audio.py
@@ -1,6 +1,7 @@
+from typing import Any
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
 
 
 class AudioToolProvider(BuiltinToolProviderController):
-    def _validate_credentials(self, credentials: dict) -> None:
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
         pass
diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py
index 15483d9768..1639dd687f 100644
--- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py
+++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py
@@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
             timezone = None
         time_format = "%Y-%m-%d %H:%M:%S"
 
-        timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
+        timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)  # type: ignore
         if not timestamp:
             yield self.create_text_message(f"Invalid localtime: {localtime}")
             return
@@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
             if isinstance(local_tz, str):
                 local_tz = pytz.timezone(local_tz)
             local_time = datetime.strptime(localtime, time_format)
-            localtime = local_tz.localize(local_time)
-            timestamp = int(localtime.timestamp())
+            localtime = local_tz.localize(local_time)  # type: ignore
+            timestamp = int(localtime.timestamp())  # type: ignore
             return timestamp
         except Exception as e:
             raise ToolInvokeError(str(e))
diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py
index d9ba259679..0ef6331530 100644
--- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py
+++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py
@@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
         """
         Convert timestamp to localtime
         """
-        timestamp = tool_parameters.get("timestamp")
+        timestamp: int = tool_parameters.get("timestamp", 0)
         timezone = tool_parameters.get("timezone", "Asia/Shanghai")
         if not timezone:
             timezone = None
diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py
index 3a091f8e70..796c38b697 100644
--- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py
+++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py
@@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
         current_time = tool_parameters.get("current_time")
         current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
         target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
-        target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
+        target_time = self.timezone_convert(current_time, current_timezone, target_timezone)  # type: ignore
         if not target_time:
             yield self.create_text_message(
                 f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
diff --git a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py
index 9d62fb5fcb..bf2199518e 100644
--- a/api/core/tools/builtin_tool/providers/webscraper/webscraper.py
+++ b/api/core/tools/builtin_tool/providers/webscraper/webscraper.py
@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
 
 
 class WebscraperProvider(BuiltinToolProviderController):
-    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
         pass
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index fd7873b083..7133535313 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = []
 
     @classmethod
-    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
+    def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
         credentials_schema = [
             ProviderConfig(
                 name="auth_type",
diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py
index 6a3c701a59..875072cb3e 100644
--- a/api/core/tools/plugin_tool/provider.py
+++ b/api/core/tools/plugin_tool/provider.py
@@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
         ):
             raise ToolProviderCredentialValidationError("Invalid credentials")
 
-    def get_tool(self, tool_name: str) -> PluginTool:
+    def get_tool(self, tool_name: str) -> PluginTool:  # type: ignore
         """
         return tool with given name
         """
@@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
             plugin_unique_identifier=self.plugin_unique_identifier,
         )
 
-    def get_tools(self) -> list[PluginTool]:
+    def get_tools(self) -> list[PluginTool]:  # type: ignore
         """
         get all tools
         """
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index 8c6dd8894b..f31a9a0d3e 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -59,7 +59,12 @@ class PluginTool(Tool):
             plugin_unique_identifier=self.plugin_unique_identifier,
         )
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         """
         get the runtime parameters
         """
@@ -76,6 +81,9 @@ class PluginTool(Tool):
                 provider=self.entity.identity.provider,
                 tool=self.entity.identity.name,
                 credentials=self.runtime.credentials,
+                conversation_id=conversation_id,
+                app_id=app_id,
+                message_id=message_id,
             )
 
         return self.runtime_parameters
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 8b7e733116..df99c82d2b 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -4,7 +4,7 @@ import mimetypes
 from collections.abc import Generator
 from os import listdir, path
 from threading import Lock
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Union, cast
 
 from yarl import URL
 
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
 
 class ToolManager:
     _builtin_provider_lock = Lock()
-    _hardcoded_providers = {}
+    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
     _builtin_providers_loaded = False
     _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
 
@@ -203,7 +203,7 @@ class ToolManager:
             if builtin_provider is None:
                 raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
         else:
-            builtin_provider: BuiltinToolProvider | None = (
+            builtin_provider = (
                 db.session.query(BuiltinToolProvider)
                 .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                 .first()
@@ -270,9 +270,7 @@ class ToolManager:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
             controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
-            controller_tools: Optional[list[Tool]] = controller.get_tools(
-                user_id="", tenant_id=workflow_provider.tenant_id
-            )
+            controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
             if controller_tools is None or len(controller_tools) == 0:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
@@ -747,18 +745,21 @@ class ToolManager:
         # add tool labels
         labels = ToolLabelManager.get_tool_labels(controller)
 
-        return jsonable_encoder(
-            {
-                "schema_type": provider_obj.schema_type,
-                "schema": provider_obj.schema,
-                "tools": provider_obj.tools,
-                "icon": icon,
-                "description": provider_obj.description,
-                "credentials": masked_credentials,
-                "privacy_policy": provider_obj.privacy_policy,
-                "custom_disclaimer": provider_obj.custom_disclaimer,
-                "labels": labels,
-            }
+        return cast(
+            dict,
+            jsonable_encoder(
+                {
+                    "schema_type": provider_obj.schema_type,
+                    "schema": provider_obj.schema,
+                    "tools": provider_obj.tools,
+                    "icon": icon,
+                    "description": provider_obj.description,
+                    "credentials": masked_credentials,
+                    "privacy_policy": provider_obj.privacy_policy,
+                    "custom_disclaimer": provider_obj.custom_disclaimer,
+                    "labels": labels,
+                }
+            ),
         )
 
     @classmethod
@@ -795,7 +796,8 @@ class ToolManager:
             if workflow_provider is None:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
 
-            return json.loads(workflow_provider.icon)
+            icon: dict = json.loads(workflow_provider.icon)
+            return icon
         except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}
 
@@ -811,7 +813,8 @@ class ToolManager:
             if api_provider is None:
                 raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
 
-            return json.loads(api_provider.icon)
+            icon: dict = json.loads(api_provider.icon)
+            return icon
         except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}
 
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index b382016473..80de31ce20 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document as RetrievalDocument
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
+from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from services.external_knowledge_service import ExternalDatasetService
diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py
index ca03b1dc94..b73dec4ebc 100644
--- a/api/core/tools/utils/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever_tool.py
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from typing import Any
+from typing import Any, Optional
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):
 
         return tools
 
-    def get_runtime_parameters(self) -> list[ToolParameter]:
+    def get_runtime_parameters(
+        self,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> list[ToolParameter]:
         return [
             ToolParameter(
                 name="query",
@@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.DATASET_RETRIEVAL
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke dataset retriever tool
         """
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index dad3a651c0..a2f85f8d7e 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
                 )
             elif message.type == ToolInvokeMessage.MessageType.FILE:
                 meta = message.meta or {}
-                file = meta.get("file")
+                file = meta.get("file", None)
                 if isinstance(file, File):
                     if file.transfer_method == FileTransferMethod.TOOL_FILE:
                         assert file.related_id is not None
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index 7e35dc7514..d16d6fc576 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
     @classmethod
     def check_is_synced(
         cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
-    ) -> bool:
+    ):
         """
         check is synced
 
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index 15e339fda0..4777a019e4 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -6,7 +6,6 @@ from pydantic import Field
 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.plugin.entities.parameters import PluginParameterOption
-from core.tools.__base.tool import Tool
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.common_entities import I18nObject
@@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
 
         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)
+            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
 
         user = db_provider.user
 
@@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):
 
         return self.tools
 
-    def get_tool(self, tool_name: str) -> Optional[Tool]:
+    def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:  # type: ignore
         """
         get tool by name
 
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 72c8359a00..3d80378bc7 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -106,9 +106,9 @@ class WorkflowTool(Tool):
         if outputs is None:
             outputs = {}
         else:
-            outputs, files = self._extract_files(outputs)
+            outputs, files = self._extract_files(outputs)  # type: ignore
             for file in files:
-                yield self.create_file_message(file)
+                yield self.create_file_message(file)  # type: ignore
 
         yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
         yield self.create_json_message(outputs)
@@ -217,7 +217,7 @@ class WorkflowTool(Tool):
         :param result: the result
         :return: the result, files
         """
-        files = []
+        files: list[File] = []
         result = {}
         for key, value in outputs.items():
             if isinstance(value, list):
@@ -238,4 +238,5 @@ class WorkflowTool(Tool):
                     files.append(file)
 
             result[key] = value
+
         return result, files
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 38f49d80dc..dabaacb31e 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -27,7 +27,7 @@ class AgentNode(ToolNode):
     Agent Node
     """
 
-    _node_data_cls = AgentNodeData
+    _node_data_cls = AgentNodeData  # type: ignore
     _node_type = NodeType.AGENT
 
     def _run(self) -> Generator:
@@ -125,7 +125,7 @@ class AgentNode(ToolNode):
         """
         agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
 
-        result = {}
+        result: dict[str, Any] = {}
        for parameter_name in node_data.agent_parameters:
            parameter = agent_parameters_dictionary.get(parameter_name)
            if not parameter:
@@ -214,7 +214,7 @@ class AgentNode(ToolNode):
         :return:
         """
         node_data = cast(AgentNodeData, node_data)
-        result = {}
+        result: dict[str, Any] = {}
         for parameter_name in node_data.agent_parameters:
             input = node_data.agent_parameters[parameter_name]
             if input.type == "mixed":
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 6909b30c9e..70a711056a 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
             db.session.close()
 
         invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
+            prompt_messages=list(prompt_messages),
             model_parameters=node_data_model.completion_params,
-            stop=stop,
+            stop=list(stop or []),
             stream=True,
             user=self.user_id,
         )
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 3adce10932..4536b8ffb0 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,5 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
         json: list[dict] = []
 
         agent_logs: list[AgentLogEvent] = []
-        agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
+        agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
 
         variables: dict[str, Any] = {}
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index f2ed9f7eda..5a7d5373c1 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -284,8 +284,6 @@ class WorkflowEntry:
                 user_inputs=user_inputs,
                 variable_pool=variable_pool,
                 tenant_id=tenant_id,
-                node_type=node_type,
-                node_data=node_instance.node_data,
             )
 
             # run node
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 4047485e04..e1553094d2 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
 from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
-from flask_restful import fields
+from flask_restful import fields  # type: ignore
 
 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
diff --git a/api/libs/login.py b/api/libs/login.py
index 174640d986..b128c53c62 100644
--- a/api/libs/login.py
+++ b/api/libs/login.py
@@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
         if "_login_user" not in g:
             current_app.login_manager._load_user()  # type: ignore
 
-        return g._login_user
+        return g._login_user  # type: ignore
 
     return None
diff --git a/api/models/account.py b/api/models/account.py
index 941dd54687..bac1ec1c2e 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -1,7 +1,7 @@
 import enum
 import json
 
-from flask_login import UserMixin
+from flask_login import UserMixin  # type: ignore
 from sqlalchemy import func
 from sqlalchemy.orm import Mapped, mapped_column
 
@@ -56,7 +56,7 @@ class Account(UserMixin, Base):
             if ta:
                 tenant.current_role = ta.role
             else:
-                tenant = None
+                tenant = None  # type: ignore
 
         self._current_tenant = tenant
 
diff --git a/api/models/model.py b/api/models/model.py
index 462fbb672e..482db7045c 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
 
 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin
+from flask_login import UserMixin  # type: ignore
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
diff --git a/api/models/tools.py b/api/models/tools.py
index e19079301b..b1e321ed85 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,6 +1,6 @@
 import json
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any, Optional, cast
 
 import sqlalchemy as sa
 from deprecated import deprecated
@@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):
 
     @property
     def credentials(self) -> dict:
-        return json.loads(self.encrypted_credentials)
+        return cast(dict, json.loads(self.encrypted_credentials))
 
 
 class ApiToolProvider(Base):
@@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
         db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
     )
 
-    # id of the tool provider
-    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     # id of the app
     app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
     # who published this tool
-    user_id = db.Column(StringUUID, nullable=False)
-    # description of the tool, stored in i18n format, for human
     description = db.Column(db.Text, nullable=False)
     # llm_description of the tool, for LLM
     llm_description = db.Column(db.Text, nullable=False)
@@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
     def description_i18n(self) -> I18nObject:
         return I18nObject(**json.loads(self.description))
 
-    @property
-    def app(self) -> App:
-        return db.session.query(App).filter(App.id == self.app_id).first()
-
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
     tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
diff --git a/api/services/agent_service.py b/api/services/agent_service.py
index 3ee23f11a7..0ff144052f 100644
--- a/api/services/agent_service.py
+++ b/api/services/agent_service.py
@@ -23,7 +23,7 @@ class AgentService:
         contexts.plugin_tool_providers.set({})
         contexts.plugin_tool_providers_lock.set(threading.Lock())
 
-        conversation: Conversation = (
+        conversation: Conversation | None = (
             db.session.query(Conversation)
             .filter(
                 Conversation.id == conversation_id,
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index 79226ffa52..bc385b2e22 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
-class ModelWithProviderEntityResponse(ModelWithProviderEntity):
+class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
     """
     Model with provider entity.
     """
diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py
index 04bdab11d9..fd1bef98b7 100644
--- a/api/services/plugin/plugin_migration.py
+++ b/api/services/plugin/plugin_migration.py
@@ -173,9 +173,8 @@ class PluginMigration:
         """
         Extract model tables.
 
-        NOTE: rename google to gemini
         """
-        models = []
+        models: list[str] = []
         table_pairs = [
             ("providers", "provider_name"),
             ("provider_models", "provider_name"),
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index f1156feafb..6f848d49c4 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -439,7 +439,7 @@ class ApiToolManageService:
                     tenant_id=tenant_id,
                 )
             )
-            result = runtime_tool.validate_credentials(credentials, parameters)
+            result = tool.validate_credentials(credentials, parameters)
         except Exception as e:
             return {"error": str(e)}
 
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index c4b9db69ec..83a42ddfcb 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from yarl import URL
 
@@ -44,7 +44,7 @@ class ToolTransformService:
         elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
             try:
                 if isinstance(icon, str):
-                    return json.loads(icon)
+                    return cast(dict, json.loads(icon))
                 return icon
             except Exception:
                 return {"background": "#252525", "content": "\ud83d\ude01"}
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index dc7d4a858c..e486ed7b8c 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -1,7 +1,7 @@
 import json
-from collections.abc import Mapping, Sequence
+from collections.abc import Mapping
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any
 
 from sqlalchemy import or_
 
@@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
+from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
 from models.model import App
 from models.tools import WorkflowToolProvider
@@ -187,7 +188,7 @@ class WorkflowToolManageService:
         """
         db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
 
-        tools: Sequence[WorkflowToolProviderController] = []
+        tools: list[WorkflowToolProviderController] = []
         for provider in db_tools:
             try:
                 tools.append(ToolTransformService.workflow_provider_to_controller(provider))
@@ -264,7 +265,7 @@ class WorkflowToolManageService:
         return cls._get_workflow_tool(tenant_id, db_tool)
 
     @classmethod
-    def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
+    def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
         """
         Get a workflow tool.
         :db_tool: the database tool
@@ -285,8 +286,8 @@ class WorkflowToolManageService:
             raise ValueError("Workflow not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
-        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
-        if to_user_tool is None or len(to_user_tool) == 0:
+        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
+        if len(workflow_tools) == 0:
             raise ValueError(f"Tool {db_tool.id} not found")
 
         return {
@@ -325,8 +326,8 @@ class WorkflowToolManageService:
             raise ValueError(f"Tool {workflow_tool_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
-        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
-        if to_user_tool is None or len(to_user_tool) == 0:
+        workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
+        if len(workflow_tools) == 0:
             raise ValueError(f"Tool {workflow_tool_id} not found")
 
         return [
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 3238842307..b370d49047 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
         for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)
+            segment_hash = helper.generate_text_hash(content)  # type: ignore
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .filter(DocumentSegment.document_id == dataset_document.id)