mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 04:36:00 +08:00
fix: mypy issues
This commit is contained in:
parent
76e24d91c0
commit
f748d6c7c4
@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[str, None, None]:
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[dict, Generator[str | dict, None, None]]:
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> dict[str, Any] | Generator[str | dict, Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
|
||||
"""
|
||||
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
|
||||
|
||||
def get_custom_credentials(self, obfuscated: bool = False):
|
||||
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
|
||||
"""
|
||||
Get custom credentials.
|
||||
|
||||
@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
|
||||
model_provider_factory = ModelProviderFactory(self.tenant_id)
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
|
||||
model_types = []
|
||||
model_types: list[ModelType] = []
|
||||
if model_type:
|
||||
model_types.append(model_type)
|
||||
else:
|
||||
model_types = provider_schema.supported_model_types
|
||||
model_types = list(provider_schema.supported_model_types)
|
||||
|
||||
# Group model settings by model type and model
|
||||
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
|
||||
@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
|
||||
def values(self) -> Iterator[ProviderConfiguration]:
|
||||
return iter(self.configurations.values())
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
|
||||
return self.configurations.get(key, default)
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
|
||||
|
||||
class ProviderModelBundle(BaseModel):
|
||||
|
@ -20,7 +20,7 @@ class UploadFileParser:
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
|
@ -48,7 +48,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
|
||||
),
|
||||
)
|
||||
answer = cast(str, response.message.content)
|
||||
@ -101,7 +101,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={"max_tokens": 256, "temperature": 0},
|
||||
stream=False,
|
||||
),
|
||||
@ -110,7 +110,7 @@ class LLMGenerator:
|
||||
questions = output_parser.parse(cast(str, response.message.content))
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
@ -150,7 +150,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,7 +200,7 @@ class LLMGenerator:
|
||||
prompt_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
except InvokeError as e:
|
||||
@ -236,7 +236,7 @@ class LLMGenerator:
|
||||
parameter_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
|
||||
@ -248,7 +248,7 @@ class LLMGenerator:
|
||||
statement_content = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
rule_config["opening_statement"] = cast(str, statement_content.message.content)
|
||||
@ -301,7 +301,7 @@ class LLMGenerator:
|
||||
response = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
plugin_model_manager = PluginModelManager()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
|
@ -285,17 +285,17 @@ class ModelProviderFactory:
|
||||
}
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
return LargeLanguageModel(**init_params)
|
||||
return LargeLanguageModel(**init_params) # type: ignore
|
||||
elif model_type == ModelType.TEXT_EMBEDDING:
|
||||
return TextEmbeddingModel(**init_params)
|
||||
return TextEmbeddingModel(**init_params) # type: ignore
|
||||
elif model_type == ModelType.RERANK:
|
||||
return RerankModel(**init_params)
|
||||
return RerankModel(**init_params) # type: ignore
|
||||
elif model_type == ModelType.SPEECH2TEXT:
|
||||
return Speech2TextModel(**init_params)
|
||||
return Speech2TextModel(**init_params) # type: ignore
|
||||
elif model_type == ModelType.MODERATION:
|
||||
return ModerationModel(**init_params)
|
||||
return ModerationModel(**init_params) # type: ignore
|
||||
elif model_type == ModelType.TTS:
|
||||
return TTSModel(**init_params)
|
||||
return TTSModel(**init_params) # type: ignore
|
||||
|
||||
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
"""
|
||||
|
@ -119,7 +119,7 @@ class BasePluginManager:
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
return type(**response.json())
|
||||
return type(**response.json()) # type: ignore
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self,
|
||||
@ -140,7 +140,7 @@ class BasePluginManager:
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
|
||||
rep = PluginDaemonBasicResponse[type](**json_response)
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
@ -171,7 +171,7 @@ class BasePluginManager:
|
||||
line_data = None
|
||||
try:
|
||||
line_data = json.loads(line)
|
||||
rep = PluginDaemonBasicResponse[type](**line_data)
|
||||
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
|
||||
except Exception:
|
||||
# TODO modify this when line_data has code and message
|
||||
if line_data and "error" in line_data:
|
||||
|
@ -742,7 +742,7 @@ class ProviderManager:
|
||||
try:
|
||||
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials: dict[str, Any] = {}
|
||||
provider_credentials = {}
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
|
@ -601,6 +601,9 @@ class DatasetRetrieval:
|
||||
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
|
||||
if retrieve_config.reranking_model is None:
|
||||
raise ValueError("Reranking model is required for multiple retrieval")
|
||||
|
||||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
|
@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
|
||||
**kwargs: Any,
|
||||
):
|
||||
def _token_encoder(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
if embedding_model_instance:
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
|
||||
return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
|
||||
else:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
s_len = self._length_function(s)
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
|
@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
separator_len = self._length_function(separator)
|
||||
separator_len = self._length_function([separator])[0]
|
||||
|
||||
docs = []
|
||||
current_doc: list[str] = []
|
||||
@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
while total > self._chunk_overlap or (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
|
||||
):
|
||||
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
|
||||
total -= self._length_function([current_doc[0]])[0] + (
|
||||
separator_len if len(current_doc) > 1 else 0
|
||||
)
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||
@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. Please install it with `pip install transformers`."
|
||||
)
|
||||
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
|
||||
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
|
@ -71,13 +71,13 @@ class Tool(ABC):
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
def single_generator():
|
||||
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield result
|
||||
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
|
||||
def generator():
|
||||
def generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from result
|
||||
|
||||
return generator()
|
||||
|
@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
timezone = None
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
|
||||
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
|
||||
if not timestamp:
|
||||
yield self.create_text_message(f"Invalid localtime: {localtime}")
|
||||
return
|
||||
@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
if isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
localtime = local_tz.localize(local_time)
|
||||
timestamp = int(localtime.timestamp())
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
|
@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
|
||||
"""
|
||||
Convert timestamp to localtime
|
||||
"""
|
||||
timestamp = tool_parameters.get("timestamp")
|
||||
timestamp: int = tool_parameters.get("timestamp", 0)
|
||||
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
|
||||
if not timezone:
|
||||
timezone = None
|
||||
|
@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
|
||||
current_time = tool_parameters.get("current_time")
|
||||
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
|
||||
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
|
||||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
|
||||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
|
||||
if not target_time:
|
||||
yield self.create_text_message(
|
||||
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
|
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = []
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
|
@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool:
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]:
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
|
@ -59,7 +59,12 @@ class PluginTool(Tool):
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
@ -76,6 +81,9 @@ class PluginTool(Tool):
|
||||
provider=self.entity.identity.provider,
|
||||
tool=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
||||
|
@ -4,7 +4,7 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers = {}
|
||||
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@ -203,7 +203,7 @@ class ToolManager:
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
else:
|
||||
builtin_provider: BuiltinToolProvider | None = (
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.first()
|
||||
@ -270,9 +270,7 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: Optional[list[Tool]] = controller.get_tools(
|
||||
user_id="", tenant_id=workflow_provider.tenant_id
|
||||
)
|
||||
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
@ -747,18 +745,21 @@ class ToolManager:
|
||||
# add tool labels
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_obj.schema_type,
|
||||
"schema": provider_obj.schema,
|
||||
"tools": provider_obj.tools,
|
||||
"icon": icon,
|
||||
"description": provider_obj.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_obj.privacy_policy,
|
||||
"custom_disclaimer": provider_obj.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
return cast(
|
||||
dict,
|
||||
jsonable_encoder(
|
||||
{
|
||||
"schema_type": provider_obj.schema_type,
|
||||
"schema": provider_obj.schema,
|
||||
"tools": provider_obj.tools,
|
||||
"icon": icon,
|
||||
"description": provider_obj.description,
|
||||
"credentials": masked_credentials,
|
||||
"privacy_policy": provider_obj.privacy_policy,
|
||||
"custom_disclaimer": provider_obj.custom_disclaimer,
|
||||
"labels": labels,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -795,7 +796,8 @@ class ToolManager:
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return json.loads(workflow_provider.icon)
|
||||
icon: dict = json.loads(workflow_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@ -811,7 +813,8 @@ class ToolManager:
|
||||
if api_provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
return json.loads(api_provider.icon)
|
||||
icon: dict = json.loads(api_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
return tools
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
return [
|
||||
ToolParameter(
|
||||
name="query",
|
||||
@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke dataset retriever tool
|
||||
"""
|
||||
|
@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file")
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
|
@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
) -> bool:
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
|
@ -6,7 +6,6 @@ from pydantic import Field
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||
|
||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
|
||||
|
||||
user = db_provider.user
|
||||
|
||||
@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
|
@ -106,9 +106,9 @@ class WorkflowTool(Tool):
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
outputs, files = self._extract_files(outputs) # type: ignore
|
||||
for file in files:
|
||||
yield self.create_file_message(file)
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs)
|
||||
@ -217,7 +217,7 @@ class WorkflowTool(Tool):
|
||||
:param result: the result
|
||||
:return: the result, files
|
||||
"""
|
||||
files = []
|
||||
files: list[File] = []
|
||||
result = {}
|
||||
for key, value in outputs.items():
|
||||
if isinstance(value, list):
|
||||
@ -238,4 +238,5 @@ class WorkflowTool(Tool):
|
||||
files.append(file)
|
||||
|
||||
result[key] = value
|
||||
|
||||
return result, files
|
||||
|
@ -27,7 +27,7 @@ class AgentNode(ToolNode):
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
_node_data_cls = AgentNodeData
|
||||
_node_data_cls = AgentNodeData # type: ignore
|
||||
_node_type = NodeType.AGENT
|
||||
|
||||
def _run(self) -> Generator:
|
||||
@ -125,7 +125,7 @@ class AgentNode(ToolNode):
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@ -214,7 +214,7 @@ class AgentNode(ToolNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(AgentNodeData, node_data)
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
input = node_data.agent_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
|
@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
|
||||
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
@ -284,8 +284,6 @@ class WorkflowEntry:
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data,
|
||||
)
|
||||
|
||||
# run node
|
||||
|
@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields
|
||||
from flask_restful import fields # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
|
@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
|
||||
if "_login_user" not in g:
|
||||
current_app.login_manager._load_user() # type: ignore
|
||||
|
||||
return g._login_user
|
||||
return g._login_user # type: ignore
|
||||
|
||||
return None
|
||||
|
@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import json
|
||||
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@ -56,7 +56,7 @@ class Account(UserMixin, Base):
|
||||
if ta:
|
||||
tenant.current_role = ta.role
|
||||
else:
|
||||
tenant = None
|
||||
tenant = None # type: ignore
|
||||
|
||||
self._current_tenant = tenant
|
||||
|
||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
return json.loads(self.encrypted_credentials)
|
||||
return cast(dict, json.loads(self.encrypted_credentials))
|
||||
|
||||
|
||||
class ApiToolProvider(Base):
|
||||
@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
|
||||
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
# id of the tool provider
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
# id of the app
|
||||
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
# who published this tool
|
||||
user_id = db.Column(StringUUID, nullable=False)
|
||||
# description of the tool, stored in i18n format, for human
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = db.Column(db.Text, nullable=False)
|
||||
@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
|
||||
def description_i18n(self) -> I18nObject:
|
||||
return I18nObject(**json.loads(self.description))
|
||||
|
||||
@property
|
||||
def app(self) -> App:
|
||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
|
@ -23,7 +23,7 @@ class AgentService:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
conversation: Conversation = (
|
||||
conversation: Conversation | None = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
|
@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
@ -173,9 +173,8 @@ class PluginMigration:
|
||||
"""
|
||||
Extract model tables.
|
||||
|
||||
NOTE: rename google to gemini
|
||||
"""
|
||||
models = []
|
||||
models: list[str] = []
|
||||
table_pairs = [
|
||||
("providers", "provider_name"),
|
||||
("provider_models", "provider_name"),
|
||||
|
@ -439,7 +439,7 @@ class ApiToolManageService:
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
result = runtime_tool.validate_credentials(credentials, parameters)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -44,7 +44,7 @@ class ToolTransformService:
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
return json.loads(icon)
|
||||
return cast(dict, json.loads(icon))
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
||||
@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
@ -187,7 +188,7 @@ class WorkflowToolManageService:
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
tools: Sequence[WorkflowToolProviderController] = []
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
try:
|
||||
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
|
||||
@ -264,7 +265,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
@ -285,8 +286,8 @@ class WorkflowToolManageService:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {db_tool.id} not found")
|
||||
|
||||
return {
|
||||
@ -325,8 +326,8 @@ class WorkflowToolManageService:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
|
||||
if to_user_tool is None or len(to_user_tool) == 0:
|
||||
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
return [
|
||||
|
@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
segment_hash = helper.generate_text_hash(content) # type: ignore
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.filter(DocumentSegment.document_id == dataset_document.id)
|
||||
|
Loading…
x
Reference in New Issue
Block a user