mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-10 05:09:02 +08:00
chore: avoid implicit optional in type annotations of method (#8727)
This commit is contained in:
parent
b360feb4c1
commit
240b66d737
@ -369,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
def _organize_historic_prompt_messages(
|
def _organize_historic_prompt_messages(
|
||||||
self, current_session_messages: list[PromptMessage] = None
|
self, current_session_messages: Optional[list[PromptMessage]] = None
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
organize historic prompt messages
|
organize historic prompt messages
|
||||||
|
@ -27,7 +27,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||||||
|
|
||||||
return SystemPromptMessage(content=system_prompt)
|
return SystemPromptMessage(content=system_prompt)
|
||||||
|
|
||||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.agent.cot_agent_runner import CotAgentRunner
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
|
||||||
@ -21,7 +22,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
|||||||
|
|
||||||
return system_prompt
|
return system_prompt
|
||||||
|
|
||||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
|
def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Organize historic prompt
|
Organize historic prompt
|
||||||
"""
|
"""
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from core.agent.base_agent_runner import BaseAgentRunner
|
from core.agent.base_agent_runner import BaseAgentRunner
|
||||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||||
@ -370,7 +370,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
def _init_system_message(
|
def _init_system_message(
|
||||||
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
|
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Initialize system message
|
Initialize system message
|
||||||
@ -385,7 +385,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
|
@ -211,9 +211,9 @@ class IndexingRunner:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
extract_settings: list[ExtractSetting],
|
extract_settings: list[ExtractSetting],
|
||||||
tmp_processing_rule: dict,
|
tmp_processing_rule: dict,
|
||||||
doc_form: str = None,
|
doc_form: Optional[str] = None,
|
||||||
doc_language: str = "English",
|
doc_language: str = "English",
|
||||||
dataset_id: str = None,
|
dataset_id: Optional[str] = None,
|
||||||
indexing_technique: str = "economy",
|
indexing_technique: str = "economy",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
|
@ -169,7 +169,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
|||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: list[Callback] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
|
@ -92,7 +92,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: list[Callback] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
|
@ -511,7 +511,7 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
|||||||
model: str,
|
model: str,
|
||||||
messages: list[PromptMessage],
|
messages: list[PromptMessage],
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
credentials: dict = None,
|
credentials: Optional[dict] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Approximate num tokens with GPT2 tokenizer.
|
Approximate num tokens with GPT2 tokenizer.
|
||||||
|
@ -111,7 +111,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
|||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: list[Callback] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
|
@ -688,7 +688,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
|||||||
model: str,
|
model: str,
|
||||||
messages: list[PromptMessage],
|
messages: list[PromptMessage],
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
credentials: dict = None,
|
credentials: Optional[dict] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Approximate num tokens with GPT2 tokenizer.
|
Approximate num tokens with GPT2 tokenizer.
|
||||||
|
@ -77,7 +77,7 @@ class SageMakerText2SpeechModel(TTSModel):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
|
||||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||||
|
|
||||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||||
|
@ -64,7 +64,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
|||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: list[Callback] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
"""
|
"""
|
||||||
Code block mode wrapper for invoking large language model
|
Code block mode wrapper for invoking large language model
|
||||||
|
@ -41,8 +41,8 @@ class Assistant(BaseAPI):
|
|||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
|
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
request_id: str = None,
|
request_id: Optional[str] = None,
|
||||||
user_id: str = None,
|
user_id: Optional[str] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
@ -72,9 +72,9 @@ class Assistant(BaseAPI):
|
|||||||
def query_support(
|
def query_support(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
assistant_id_list: list[str] = None,
|
assistant_id_list: Optional[list[str]] = None,
|
||||||
request_id: str = None,
|
request_id: Optional[str] = None,
|
||||||
user_id: str = None,
|
user_id: Optional[str] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
@ -99,8 +99,8 @@ class Assistant(BaseAPI):
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 10,
|
page_size: int = 10,
|
||||||
*,
|
*,
|
||||||
request_id: str = None,
|
request_id: Optional[str] = None,
|
||||||
user_id: str = None,
|
user_id: Optional[str] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import TYPE_CHECKING, Literal, cast
|
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@ -34,11 +34,11 @@ class Files(BaseAPI):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
file: FileTypes = None,
|
file: Optional[FileTypes] = None,
|
||||||
upload_detail: list[UploadDetail] = None,
|
upload_detail: Optional[list[UploadDetail]] = None,
|
||||||
purpose: Literal["fine-tune", "retrieval", "batch"],
|
purpose: Literal["fine-tune", "retrieval", "batch"],
|
||||||
knowledge_id: str = None,
|
knowledge_id: Optional[str] = None,
|
||||||
sentence_size: int = None,
|
sentence_size: Optional[int] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
@ -34,12 +34,12 @@ class Document(BaseAPI):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
file: FileTypes = None,
|
file: Optional[FileTypes] = None,
|
||||||
custom_separator: Optional[list[str]] = None,
|
custom_separator: Optional[list[str]] = None,
|
||||||
upload_detail: list[UploadDetail] = None,
|
upload_detail: Optional[list[UploadDetail]] = None,
|
||||||
purpose: Literal["retrieval"],
|
purpose: Literal["retrieval"],
|
||||||
knowledge_id: str = None,
|
knowledge_id: Optional[str] = None,
|
||||||
sentence_size: int = None,
|
sentence_size: Optional[int] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
@ -31,11 +31,11 @@ class Videos(BaseAPI):
|
|||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
*,
|
*,
|
||||||
prompt: str = None,
|
prompt: Optional[str] = None,
|
||||||
image_url: str = None,
|
image_url: Optional[str] = None,
|
||||||
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
|
||||||
request_id: str = None,
|
request_id: Optional[str] = None,
|
||||||
user_id: str = None,
|
user_id: Optional[str] = None,
|
||||||
extra_headers: Headers | None = None,
|
extra_headers: Headers | None = None,
|
||||||
extra_body: Body | None = None,
|
extra_body: Body | None = None,
|
||||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||||
|
@ -162,7 +162,7 @@ class RelytVector(BaseVector):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_uuids(self, ids: list[str] = None):
|
def delete_by_uuids(self, ids: Optional[list[str]] = None):
|
||||||
"""Delete by vector IDs.
|
"""Delete by vector IDs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.embedding.cached_embedding import CacheEmbedding
|
from core.embedding.cached_embedding import CacheEmbedding
|
||||||
@ -25,7 +25,7 @@ class AbstractVectorFactory(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class Vector:
|
class Vector:
|
||||||
def __init__(self, dataset: Dataset, attributes: list = None):
|
def __init__(self, dataset: Dataset, attributes: Optional[list] = None):
|
||||||
if attributes is None:
|
if attributes is None:
|
||||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||||
self._dataset = dataset
|
self._dataset = dataset
|
||||||
@ -106,7 +106,7 @@ class Vector:
|
|||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||||
|
|
||||||
def create(self, texts: list = None, **kwargs):
|
def create(self, texts: Optional[list] = None, **kwargs):
|
||||||
if texts:
|
if texts:
|
||||||
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
|
||||||
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -84,7 +84,7 @@ class ExtractProcessor:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract(
|
def extract(
|
||||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None
|
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
if extract_setting.datasource_type == DatasourceType.FILE.value:
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.rag.extractor.extractor_base import BaseExtractor
|
from core.rag.extractor.extractor_base import BaseExtractor
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
@ -17,7 +18,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
api_url: str = None,
|
api_url: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Initialize with file path."""
|
"""Initialize with file path."""
|
||||||
self._file_path = file_path
|
self._file_path = file_path
|
||||||
|
@ -341,7 +341,7 @@ class ToolRuntimeVariablePool(BaseModel):
|
|||||||
|
|
||||||
self.pool.append(variable)
|
self.pool.append(variable)
|
||||||
|
|
||||||
def set_file(self, tool_name: str, value: str, name: str = None) -> None:
|
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
set an image variable
|
set an image variable
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class SageMakerTTSTool(BuiltinTool):
|
|||||||
s3_client: Any = None
|
s3_client: Any = None
|
||||||
comprehend_client: Any = None
|
comprehend_client: Any = None
|
||||||
|
|
||||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
|
||||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||||
|
|
||||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult
|
||||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
@ -124,7 +126,7 @@ class BuiltinTool(Tool):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_url(self, url: str, user_agent: str = None) -> str:
|
def get_url(self, url: str, user_agent: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
get url
|
get url
|
||||||
"""
|
"""
|
||||||
|
@ -318,7 +318,7 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)
|
||||||
|
|
||||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = "") -> ToolInvokeMessage:
|
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
create a blob message
|
create a blob message
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import mimetypes
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from os import listdir, path
|
from os import listdir, path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
@ -72,7 +72,7 @@ class ToolManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool(
|
def get_tool(
|
||||||
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None
|
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
|
||||||
) -> Union[BuiltinTool, ApiTool]:
|
) -> Union[BuiltinTool, ApiTool]:
|
||||||
"""
|
"""
|
||||||
get the tool
|
get the tool
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
@ -32,7 +34,12 @@ class FeishuRequest:
|
|||||||
return res.get("tenant_access_token")
|
return res.get("tenant_access_token")
|
||||||
|
|
||||||
def _send_request(
|
def _send_request(
|
||||||
self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None
|
self,
|
||||||
|
url: str,
|
||||||
|
method: str = "post",
|
||||||
|
require_token: bool = True,
|
||||||
|
payload: Optional[dict] = None,
|
||||||
|
params: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
@ -3,6 +3,7 @@ import uuid
|
|||||||
from json import dumps as json_dumps
|
from json import dumps as json_dumps
|
||||||
from json import loads as json_loads
|
from json import loads as json_loads
|
||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from requests import get
|
from requests import get
|
||||||
from yaml import YAMLError, safe_load
|
from yaml import YAMLError, safe_load
|
||||||
@ -16,7 +17,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
|
|||||||
class ApiBasedToolSchemaParser:
|
class ApiBasedToolSchemaParser:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_openapi_to_tool_bundle(
|
def parse_openapi_to_tool_bundle(
|
||||||
openapi: dict, extra_info: dict = None, warning: dict = None
|
openapi: dict, extra_info: Optional[dict], warning: Optional[dict]
|
||||||
) -> list[ApiToolBundle]:
|
) -> list[ApiToolBundle]:
|
||||||
warning = warning if warning is not None else {}
|
warning = warning if warning is not None else {}
|
||||||
extra_info = extra_info if extra_info is not None else {}
|
extra_info = extra_info if extra_info is not None else {}
|
||||||
@ -174,7 +175,7 @@ class ApiBasedToolSchemaParser:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_openapi_yaml_to_tool_bundle(
|
def parse_openapi_yaml_to_tool_bundle(
|
||||||
yaml: str, extra_info: dict = None, warning: dict = None
|
yaml: str, extra_info: Optional[dict], warning: Optional[dict]
|
||||||
) -> list[ApiToolBundle]:
|
) -> list[ApiToolBundle]:
|
||||||
"""
|
"""
|
||||||
parse openapi yaml to tool bundle
|
parse openapi yaml to tool bundle
|
||||||
@ -191,7 +192,7 @@ class ApiBasedToolSchemaParser:
|
|||||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
|
def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: Optional[dict]) -> dict:
|
||||||
"""
|
"""
|
||||||
parse swagger to openapi
|
parse swagger to openapi
|
||||||
|
|
||||||
@ -253,7 +254,7 @@ class ApiBasedToolSchemaParser:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_openai_plugin_json_to_tool_bundle(
|
def parse_openai_plugin_json_to_tool_bundle(
|
||||||
json: str, extra_info: dict = None, warning: dict = None
|
json: str, extra_info: Optional[dict], warning: Optional[dict]
|
||||||
) -> list[ApiToolBundle]:
|
) -> list[ApiToolBundle]:
|
||||||
"""
|
"""
|
||||||
parse openapi plugin yaml to tool bundle
|
parse openapi plugin yaml to tool bundle
|
||||||
@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def auto_parse_to_tool_bundle(
|
def auto_parse_to_tool_bundle(
|
||||||
content: str, extra_info: dict = None, warning: dict = None
|
content: str, extra_info: Optional[dict], warning: Optional[dict]
|
||||||
) -> tuple[list[ApiToolBundle], str]:
|
) -> tuple[list[ApiToolBundle], str]:
|
||||||
"""
|
"""
|
||||||
auto parse to tool bundle
|
auto parse to tool bundle
|
||||||
|
@ -9,6 +9,7 @@ import tempfile
|
|||||||
import unicodedata
|
import unicodedata
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
import chardet
|
import chardet
|
||||||
@ -36,7 +37,7 @@ def page_result(text: str, cursor: int, max_length: int) -> str:
|
|||||||
return text[cursor : cursor + max_length]
|
return text[cursor : cursor + max_length]
|
||||||
|
|
||||||
|
|
||||||
def get_url(url: str, user_agent: str = None) -> str:
|
def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||||
"""Fetch URL and return the contents as a string."""
|
"""Fetch URL and return the contents as a string."""
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||||
|
@ -189,7 +189,7 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
|
|||||||
|
|
||||||
class TokenManager:
|
class TokenManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
def generate_token(cls, account: Account, token_type: str, additional_data: Optional[dict] = None) -> str:
|
||||||
old_token = cls._get_current_token_for_account(account.id, token_type)
|
old_token = cls._get_current_token_for_account(account.id, token_type)
|
||||||
if old_token:
|
if old_token:
|
||||||
if isinstance(old_token, bytes):
|
if isinstance(old_token, bytes):
|
||||||
|
@ -28,6 +28,7 @@ select = [
|
|||||||
"PLR0402", # manual-from-import
|
"PLR0402", # manual-from-import
|
||||||
"PLR1711", # useless-return
|
"PLR1711", # useless-return
|
||||||
"PLR1714", # repeated-equality-comparison
|
"PLR1714", # repeated-equality-comparison
|
||||||
|
"RUF013", # implicit-optional
|
||||||
"RUF019", # unnecessary-key-check
|
"RUF019", # unnecessary-key-check
|
||||||
"RUF100", # unused-noqa
|
"RUF100", # unused-noqa
|
||||||
"RUF101", # redirected-noqa
|
"RUF101", # redirected-noqa
|
||||||
|
@ -321,7 +321,7 @@ class TenantService:
|
|||||||
return tenant
|
return tenant
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def switch_tenant(account: Account, tenant_id: int = None) -> None:
|
def switch_tenant(account: Account, tenant_id: Optional[int] = None) -> None:
|
||||||
"""Switch the current workspace for the account"""
|
"""Switch the current workspace for the account"""
|
||||||
|
|
||||||
# Ensure tenant_id is provided
|
# Ensure tenant_id is provided
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class BaseServiceError(Exception):
|
class BaseServiceError(Exception):
|
||||||
def __init__(self, description: str = None):
|
def __init__(self, description: Optional[str] = None):
|
||||||
self.description = description
|
self.description = description
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
@ -11,7 +12,7 @@ from models.model import App, Tag, TagBinding
|
|||||||
|
|
||||||
class TagService:
|
class TagService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
|
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
|
||||||
query = (
|
query = (
|
||||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from httpx import get
|
from httpx import get
|
||||||
|
|
||||||
@ -79,7 +80,7 @@ class ApiToolManageService:
|
|||||||
raise ValueError(f"invalid schema: {str(e)}")
|
raise ValueError(f"invalid schema: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
def convert_schema_to_tool_bundles(schema: str, extra_info: Optional[dict] = None) -> list[ApiToolBundle]:
|
||||||
"""
|
"""
|
||||||
convert schema to tool bundles
|
convert schema to tool bundles
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ class ToolTransformService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def workflow_provider_to_user_provider(
|
def workflow_provider_to_user_provider(
|
||||||
provider_controller: WorkflowToolProviderController, labels: list[str] = None
|
provider_controller: WorkflowToolProviderController, labels: Optional[list[str]] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
convert provider controller to user provider
|
convert provider controller to user provider
|
||||||
@ -174,7 +174,7 @@ class ToolTransformService:
|
|||||||
provider_controller: ApiToolProviderController,
|
provider_controller: ApiToolProviderController,
|
||||||
db_provider: ApiToolProvider,
|
db_provider: ApiToolProvider,
|
||||||
decrypt_credentials: bool = True,
|
decrypt_credentials: bool = True,
|
||||||
labels: list[str] = None,
|
labels: Optional[list[str]] = None,
|
||||||
) -> UserToolProvider:
|
) -> UserToolProvider:
|
||||||
"""
|
"""
|
||||||
convert provider controller to user provider
|
convert provider controller to user provider
|
||||||
@ -223,9 +223,9 @@ class ToolTransformService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_to_user_tool(
|
def tool_to_user_tool(
|
||||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||||
credentials: dict = None,
|
credentials: Optional[dict] = None,
|
||||||
tenant_id: str = None,
|
tenant_id: Optional[str] = None,
|
||||||
labels: list[str] = None,
|
labels: Optional[list[str]] = None,
|
||||||
) -> UserTool:
|
) -> UserTool:
|
||||||
"""
|
"""
|
||||||
convert tool to user tool
|
convert tool to user tool
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ class WorkflowToolManageService:
|
|||||||
description: str,
|
description: str,
|
||||||
parameters: list[dict],
|
parameters: list[dict],
|
||||||
privacy_policy: str = "",
|
privacy_policy: str = "",
|
||||||
labels: list[str] = None,
|
labels: Optional[list[str]] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Create a workflow tool.
|
Create a workflow tool.
|
||||||
@ -106,7 +107,7 @@ class WorkflowToolManageService:
|
|||||||
description: str,
|
description: str,
|
||||||
parameters: list[dict],
|
parameters: list[dict],
|
||||||
privacy_policy: str = "",
|
privacy_policy: str = "",
|
||||||
labels: list[str] = None,
|
labels: Optional[list[str]] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Update a workflow tool.
|
Update a workflow tool.
|
||||||
|
@ -48,7 +48,7 @@ class MockTcvectordbClass:
|
|||||||
description: str,
|
description: str,
|
||||||
index: Index,
|
index: Index,
|
||||||
embedding: Embedding = None,
|
embedding: Embedding = None,
|
||||||
timeout: float = None,
|
timeout: Optional[float] = None,
|
||||||
) -> Collection:
|
) -> Collection:
|
||||||
return Collection(
|
return Collection(
|
||||||
self,
|
self,
|
||||||
@ -97,9 +97,9 @@ class MockTcvectordbClass:
|
|||||||
|
|
||||||
def collection_delete(
|
def collection_delete(
|
||||||
self,
|
self,
|
||||||
document_ids: list[str] = None,
|
document_ids: Optional[list[str]] = None,
|
||||||
filter: Filter = None,
|
filter: Filter = None,
|
||||||
timeout: float = None,
|
timeout: Optional[float] = None,
|
||||||
):
|
):
|
||||||
return {"code": 0, "msg": "operation success"}
|
return {"code": 0, "msg": "operation success"}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user