fix: mypy issues

This commit is contained in:
Yeuoly 2025-01-09 16:53:30 +08:00
parent 76e24d91c0
commit f748d6c7c4
49 changed files with 157 additions and 133 deletions

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod @abstractmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]: ) -> Generator[dict | str, None, None]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse, ChatbotAppBlockingResponse,
ChatbotAppStreamResponse, ChatbotAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
CompletionAppStreamResponse, CompletionAppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Union[dict, Generator[str | dict, None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow, workflow: Workflow,
node_id: str, node_id: str,
user: Account | EndUser, user: Account | EndUser,
args: dict, args: Mapping[str, Any],
streaming: bool = True, streaming: bool = True,
) -> dict[str, Any] | Generator[str | dict, Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
""" """
Generate App response. Generate App response.

View File

@ -3,6 +3,7 @@ from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse, ErrorStreamResponse,
NodeFinishStreamResponse, NodeFinishStreamResponse,
NodeStartStreamResponse, NodeStartStreamResponse,
@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_full_response( def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream full response. Convert stream full response.
@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod @classmethod
def convert_stream_simple_response( def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]: ) -> Generator[dict | str, None, None]:
""" """
Convert stream simple response. Convert stream simple response.

View File

@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
""" """
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
def get_custom_credentials(self, obfuscated: bool = False): def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
""" """
Get custom credentials. Get custom credentials.
@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
model_provider_factory = ModelProviderFactory(self.tenant_id) model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
model_types = [] model_types: list[ModelType] = []
if model_type: if model_type:
model_types.append(model_type) model_types.append(model_type)
else: else:
model_types = provider_schema.supported_model_types model_types = list(provider_schema.supported_model_types)
# Group model settings by model type and model # Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
def values(self) -> Iterator[ProviderConfiguration]: def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values()) return iter(self.configurations.values())
def get(self, key, default=None): def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key: if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
return self.configurations.get(key, default) return self.configurations.get(key, default) # type: ignore
class ProviderModelBundle(BaseModel): class ProviderModelBundle(BaseModel):

View File

@ -20,7 +20,7 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS: if upload_file.extension not in IMAGE_EXTENSIONS:
return None return None
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id) return cls.get_signed_temp_image_url(upload_file.id)
else: else:
# get image file base64 # get image file base64

View File

@ -48,7 +48,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
), ),
) )
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
@ -101,7 +101,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0}, model_parameters={"max_tokens": 256, "temperature": 0},
stream=False, stream=False,
), ),
@ -110,7 +110,7 @@ class LLMGenerator:
questions = output_parser.parse(cast(str, response.message.content)) questions = output_parser.parse(cast(str, response.message.content))
except InvokeError: except InvokeError:
questions = [] questions = []
except Exception as e: except Exception:
logging.exception("Failed to generate suggested questions after answer") logging.exception("Failed to generate suggested questions after answer")
questions = [] questions = []
@ -150,7 +150,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )
@ -200,7 +200,7 @@ class LLMGenerator:
prompt_content = cast( prompt_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )
except InvokeError as e: except InvokeError as e:
@ -236,7 +236,7 @@ class LLMGenerator:
parameter_content = cast( parameter_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
), ),
) )
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@ -248,7 +248,7 @@ class LLMGenerator:
statement_content = cast( statement_content = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
), ),
) )
rule_config["opening_statement"] = cast(str, statement_content.message.content) rule_config["opening_statement"] = cast(str, statement_content.message.content)
@ -301,7 +301,7 @@ class LLMGenerator:
response = cast( response = cast(
LLMResult, LLMResult,
model_instance.invoke_llm( model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
), ),
) )

View File

@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks, callbacks=callbacks,
) )
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try: try:
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
result = plugin_model_manager.invoke_llm( result = plugin_model_manager.invoke_llm(

View File

@ -285,17 +285,17 @@ class ModelProviderFactory:
} }
if model_type == ModelType.LLM: if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) return LargeLanguageModel(**init_params) # type: ignore
elif model_type == ModelType.TEXT_EMBEDDING: elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) return TextEmbeddingModel(**init_params) # type: ignore
elif model_type == ModelType.RERANK: elif model_type == ModelType.RERANK:
return RerankModel(**init_params) return RerankModel(**init_params) # type: ignore
elif model_type == ModelType.SPEECH2TEXT: elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) return Speech2TextModel(**init_params) # type: ignore
elif model_type == ModelType.MODERATION: elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) return ModerationModel(**init_params) # type: ignore
elif model_type == ModelType.TTS: elif model_type == ModelType.TTS:
return TTSModel(**init_params) return TTSModel(**init_params) # type: ignore
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
""" """

View File

@ -119,7 +119,7 @@ class BasePluginManager:
Make a request to the plugin daemon inner API and return the response as a model. Make a request to the plugin daemon inner API and return the response as a model.
""" """
response = self._request(method, path, headers, data, params, files) response = self._request(method, path, headers, data, params, files)
return type(**response.json()) return type(**response.json()) # type: ignore
def _request_with_plugin_daemon_response( def _request_with_plugin_daemon_response(
self, self,
@ -140,7 +140,7 @@ class BasePluginManager:
if transformer: if transformer:
json_response = transformer(json_response) json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
if rep.code != 0: if rep.code != 0:
try: try:
error = PluginDaemonError(**json.loads(rep.message)) error = PluginDaemonError(**json.loads(rep.message))
@ -171,7 +171,7 @@ class BasePluginManager:
line_data = None line_data = None
try: try:
line_data = json.loads(line) line_data = json.loads(line)
rep = PluginDaemonBasicResponse[type](**line_data) rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
except Exception: except Exception:
# TODO modify this when line_data has code and message # TODO modify this when line_data has code and message
if line_data and "error" in line_data: if line_data and "error" in line_data:

View File

@ -742,7 +742,7 @@ class ProviderManager:
try: try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError: except JSONDecodeError:
provider_credentials: dict[str, Any] = {} provider_credentials = {}
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(

View File

@ -601,6 +601,9 @@ class DatasetRetrieval:
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
if retrieve_config.reranking_model is None:
raise ValueError("Reranking model is required for multiple retrieval")
tool = DatasetMultiRetrieverTool.from_dataset( tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets], dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id, tenant_id=tenant_id,

View File

@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any, **kwargs: Any,
): ):
def _token_encoder(text: str) -> int: def _token_encoder(texts: list[str]) -> list[int]:
if not text: if not texts:
return 0 return []
if embedding_model_instance: if embedding_model_instance:
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
else: else:
return GPT2Tokenizer.get_num_tokens(text) return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
if issubclass(cls, TokenTextSplitter): if issubclass(cls, TokenTextSplitter):
extra_kwargs = { extra_kwargs = {
@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
_good_splits_lengths = [] # cache the lengths of the splits _good_splits_lengths = [] # cache the lengths of the splits
s_lens = self._length_function(splits) s_lens = self._length_function(splits)
for s, s_len in zip(splits, s_lens): for s, s_len in zip(splits, s_lens):
s_len = self._length_function(s)
if s_len < self._chunk_size: if s_len < self._chunk_size:
_good_splits.append(s) _good_splits.append(s)
_good_splits_lengths.append(s_len) _good_splits_lengths.append(s_len)

View File

@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size # We now want to combine these smaller pieces into medium size
# chunks to send to the LLM. # chunks to send to the LLM.
separator_len = self._length_function(separator) separator_len = self._length_function([separator])[0]
docs = [] docs = []
current_doc: list[str] = [] current_doc: list[str] = []
@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
while total > self._chunk_overlap or ( while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
): ):
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) total -= self._length_function([current_doc[0]])[0] + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:] current_doc = current_doc[1:]
current_doc.append(d) current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0) total += _len + (separator_len if len(current_doc) > 1 else 0)
@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
raise ValueError( raise ValueError(
"Could not import transformers python package. Please install it with `pip install transformers`." "Could not import transformers python package. Please install it with `pip install transformers`."
) )
return cls(length_function=_huggingface_tokenizer_length, **kwargs) return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
@classmethod @classmethod
def from_tiktoken_encoder( def from_tiktoken_encoder(
@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
} }
kwargs = {**kwargs, **extra_kwargs} kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_tiktoken_encoder, **kwargs) return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them.""" """Transform sequence of documents by splitting them."""

View File

@ -71,13 +71,13 @@ class Tool(ABC):
if isinstance(result, ToolInvokeMessage): if isinstance(result, ToolInvokeMessage):
def single_generator(): def single_generator() -> Generator[ToolInvokeMessage, None, None]:
yield result yield result
return single_generator() return single_generator()
elif isinstance(result, list): elif isinstance(result, list):
def generator(): def generator() -> Generator[ToolInvokeMessage, None, None]:
yield from result yield from result
return generator() return generator()

View File

@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
return self._get_builtin_tools() return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None: def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
""" """
returns the tool that the provider can provide returns the tool that the provider can provide
""" """
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
@property @property
def need_credentials(self) -> bool: def need_credentials(self) -> bool:

View File

@ -1,6 +1,7 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController): class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None: def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass pass

View File

@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
timezone = None timezone = None
time_format = "%Y-%m-%d %H:%M:%S" time_format = "%Y-%m-%d %H:%M:%S"
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
if not timestamp: if not timestamp:
yield self.create_text_message(f"Invalid localtime: {localtime}") yield self.create_text_message(f"Invalid localtime: {localtime}")
return return
@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
if isinstance(local_tz, str): if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz) local_tz = pytz.timezone(local_tz)
local_time = datetime.strptime(localtime, time_format) local_time = datetime.strptime(localtime, time_format)
localtime = local_tz.localize(local_time) localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) timestamp = int(localtime.timestamp()) # type: ignore
return timestamp return timestamp
except Exception as e: except Exception as e:
raise ToolInvokeError(str(e)) raise ToolInvokeError(str(e))

View File

@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
""" """
Convert timestamp to localtime Convert timestamp to localtime
""" """
timestamp = tool_parameters.get("timestamp") timestamp: int = tool_parameters.get("timestamp", 0)
timezone = tool_parameters.get("timezone", "Asia/Shanghai") timezone = tool_parameters.get("timezone", "Asia/Shanghai")
if not timezone: if not timezone:
timezone = None timezone = None

View File

@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
current_time = tool_parameters.get("current_time") current_time = tool_parameters.get("current_time")
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai") current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo") target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time: if not target_time:
yield self.create_text_message( yield self.create_text_message(
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"

View File

@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController): class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass pass

View File

@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = [] self.tools = []
@classmethod @classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType): def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = [ credentials_schema = [
ProviderConfig( ProviderConfig(
name="auth_type", name="auth_type",

View File

@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
): ):
raise ToolProviderCredentialValidationError("Invalid credentials") raise ToolProviderCredentialValidationError("Invalid credentials")
def get_tool(self, tool_name: str) -> PluginTool: def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
""" """
return tool with given name return tool with given name
""" """
@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
plugin_unique_identifier=self.plugin_unique_identifier, plugin_unique_identifier=self.plugin_unique_identifier,
) )
def get_tools(self) -> list[PluginTool]: def get_tools(self) -> list[PluginTool]: # type: ignore
""" """
get all tools get all tools
""" """

View File

@ -59,7 +59,12 @@ class PluginTool(Tool):
plugin_unique_identifier=self.plugin_unique_identifier, plugin_unique_identifier=self.plugin_unique_identifier,
) )
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
""" """
get the runtime parameters get the runtime parameters
""" """
@ -76,6 +81,9 @@ class PluginTool(Tool):
provider=self.entity.identity.provider, provider=self.entity.identity.provider,
tool=self.entity.identity.name, tool=self.entity.identity.name,
credentials=self.runtime.credentials, credentials=self.runtime.credentials,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
) )
return self.runtime_parameters return self.runtime_parameters

View File

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL from yarl import URL
@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
class ToolManager: class ToolManager:
_builtin_provider_lock = Lock() _builtin_provider_lock = Lock()
_hardcoded_providers = {} _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False _builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@ -203,7 +203,7 @@ class ToolManager:
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else: else:
builtin_provider: BuiltinToolProvider | None = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.first() .first()
@ -270,9 +270,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools( controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
user_id="", tenant_id=workflow_provider.tenant_id
)
if controller_tools is None or len(controller_tools) == 0: if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
@ -747,7 +745,9 @@ class ToolManager:
# add tool labels # add tool labels
labels = ToolLabelManager.get_tool_labels(controller) labels = ToolLabelManager.get_tool_labels(controller)
return jsonable_encoder( return cast(
dict,
jsonable_encoder(
{ {
"schema_type": provider_obj.schema_type, "schema_type": provider_obj.schema_type,
"schema": provider_obj.schema, "schema": provider_obj.schema,
@ -759,6 +759,7 @@ class ToolManager:
"custom_disclaimer": provider_obj.custom_disclaimer, "custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels, "labels": labels,
} }
),
) )
@classmethod @classmethod
@ -795,7 +796,8 @@ class ToolManager:
if workflow_provider is None: if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon) icon: dict = json.loads(workflow_provider.icon)
return icon
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
@ -811,7 +813,8 @@ class ToolManager:
if api_provider is None: if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found") raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
return json.loads(api_provider.icon) icon: dict = json.loads(api_provider.icon)
return icon
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService

View File

@ -1,5 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):
return tools return tools
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
return [ return [
ToolParameter( ToolParameter(
name="query", name="query",
@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL return ToolProviderType.DATASET_RETRIEVAL
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
""" """
invoke dataset retriever tool invoke dataset retriever tool
""" """

View File

@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
) )
elif message.type == ToolInvokeMessage.MessageType.FILE: elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {} meta = message.meta or {}
file = meta.get("file") file = meta.get("file", None)
if isinstance(file, File): if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE: if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None assert file.related_id is not None

View File

@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod @classmethod
def check_is_synced( def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> bool: ):
""" """
check is synced check is synced

View File

@ -6,7 +6,6 @@ from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None) return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
user = db_provider.user user = db_provider.user
@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools return self.tools
def get_tool(self, tool_name: str) -> Optional[Tool]: def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore
""" """
get tool by name get tool by name

View File

@ -106,9 +106,9 @@ class WorkflowTool(Tool):
if outputs is None: if outputs is None:
outputs = {} outputs = {}
else: else:
outputs, files = self._extract_files(outputs) outputs, files = self._extract_files(outputs) # type: ignore
for file in files: for file in files:
yield self.create_file_message(file) yield self.create_file_message(file) # type: ignore
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs) yield self.create_json_message(outputs)
@ -217,7 +217,7 @@ class WorkflowTool(Tool):
:param result: the result :param result: the result
:return: the result, files :return: the result, files
""" """
files = [] files: list[File] = []
result = {} result = {}
for key, value in outputs.items(): for key, value in outputs.items():
if isinstance(value, list): if isinstance(value, list):
@ -238,4 +238,5 @@ class WorkflowTool(Tool):
files.append(file) files.append(file)
result[key] = value result[key] = value
return result, files return result, files

View File

@ -27,7 +27,7 @@ class AgentNode(ToolNode):
Agent Node Agent Node
""" """
_node_data_cls = AgentNodeData _node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT _node_type = NodeType.AGENT
def _run(self) -> Generator: def _run(self) -> Generator:
@ -125,7 +125,7 @@ class AgentNode(ToolNode):
""" """
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result = {} result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters: for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name) parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter: if not parameter:
@ -214,7 +214,7 @@ class AgentNode(ToolNode):
:return: :return:
""" """
node_data = cast(AgentNodeData, node_data) node_data = cast(AgentNodeData, node_data)
result = {} result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters: for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name] input = node_data.agent_parameters[parameter_name]
if input.type == "mixed": if input.type == "mixed":

View File

@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
db.session.close() db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
stop=stop, stop=list(stop or []),
stream=True, stream=True,
user=self.user_id, user=self.user_id,
) )

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = [] json: list[dict] = []
agent_logs: list[AgentLogEvent] = [] agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {} agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
variables: dict[str, Any] = {} variables: dict[str, Any] = {}

View File

@ -284,8 +284,6 @@ class WorkflowEntry:
user_inputs=user_inputs, user_inputs=user_inputs,
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=tenant_id, tenant_id=tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
) )
# run node # run node

View File

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones from zoneinfo import available_timezones
from flask import Response, stream_with_context from flask import Response, stream_with_context
from flask_restful import fields from flask_restful import fields # type: ignore
from configs import dify_config from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g: if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore current_app.login_manager._load_user() # type: ignore
return g._login_user return g._login_user # type: ignore
return None return None

View File

@ -1,7 +1,7 @@
import enum import enum
import json import json
from flask_login import UserMixin from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -56,7 +56,7 @@ class Account(UserMixin, Base):
if ta: if ta:
tenant.current_role = ta.role tenant.current_role = ta.role
else: else:
tenant = None tenant = None # type: ignore
self._current_tenant = tenant self._current_tenant = tenant

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin # type: ignore
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column

View File

@ -1,6 +1,6 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional, cast
import sqlalchemy as sa import sqlalchemy as sa
from deprecated import deprecated from deprecated import deprecated
@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:
return json.loads(self.encrypted_credentials) return cast(dict, json.loads(self.encrypted_credentials))
class ApiToolProvider(Base): class ApiToolProvider(Base):
@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
) )
# id of the tool provider
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app # id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
# who published this tool # who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False) description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM # llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False) llm_description = db.Column(db.Text, nullable=False)
@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> I18nObject: def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description)) return I18nObject(**json.loads(self.description))
@property
def app(self) -> App:
return db.session.query(App).filter(App.id == self.app_id).first()
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
user_id: Mapped[str] = db.Column(StringUUID, nullable=False) user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)

View File

@ -23,7 +23,7 @@ class AgentService:
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
conversation: Conversation = ( conversation: Conversation | None = (
db.session.query(Conversation) db.session.query(Conversation)
.filter( .filter(
Conversation.id == conversation_id, Conversation.id == conversation_id,

View File

@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class ModelWithProviderEntityResponse(ModelWithProviderEntity): class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
""" """
Model with provider entity. Model with provider entity.
""" """

View File

@ -173,9 +173,8 @@ class PluginMigration:
""" """
Extract model tables. Extract model tables.
NOTE: rename google to gemini
""" """
models = [] models: list[str] = []
table_pairs = [ table_pairs = [
("providers", "provider_name"), ("providers", "provider_name"),
("provider_models", "provider_name"), ("provider_models", "provider_name"),

View File

@ -439,7 +439,7 @@ class ApiToolManageService:
tenant_id=tenant_id, tenant_id=tenant_id,
) )
) )
result = runtime_tool.validate_credentials(credentials, parameters) result = tool.validate_credentials(credentials, parameters)
except Exception as e: except Exception as e:
return {"error": str(e)} return {"error": str(e)}

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Optional, Union from typing import Optional, Union, cast
from yarl import URL from yarl import URL
@ -44,7 +44,7 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try: try:
if isinstance(icon, str): if isinstance(icon, str):
return json.loads(icon) return cast(dict, json.loads(icon))
return icon return icon
except Exception: except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}

View File

@ -1,7 +1,7 @@
import json import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any
from sqlalchemy import or_ from sqlalchemy import or_
@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App from models.model import App
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
@ -187,7 +188,7 @@ class WorkflowToolManageService:
""" """
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
tools: Sequence[WorkflowToolProviderController] = [] tools: list[WorkflowToolProviderController] = []
for provider in db_tools: for provider in db_tools:
try: try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider)) tools.append(ToolTransformService.workflow_provider_to_controller(provider))
@ -264,7 +265,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool) return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod @classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
""" """
Get a workflow tool. Get a workflow tool.
:db_tool: the database tool :db_tool: the database tool
@ -285,8 +286,8 @@ class WorkflowToolManageService:
raise ValueError("Workflow not found") raise ValueError("Workflow not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool) tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if to_user_tool is None or len(to_user_tool) == 0: if len(workflow_tools) == 0:
raise ValueError(f"Tool {db_tool.id} not found") raise ValueError(f"Tool {db_tool.id} not found")
return { return {
@ -325,8 +326,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found") raise ValueError(f"Tool {workflow_tool_id} not found")
tool = ToolTransformService.workflow_provider_to_controller(db_tool) tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if to_user_tool is None or len(to_user_tool) == 0: if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found") raise ValueError(f"Tool {workflow_tool_id} not found")
return [ return [

View File

@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
for segment, tokens in zip(content, tokens_list): for segment, tokens in zip(content, tokens_list):
content = segment["content"] content = segment["content"]
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content) segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id) .filter(DocumentSegment.document_id == dataset_document.id)