From dcb72e0067b4416161d92ab183f841f5ea4dcadb Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Thu, 27 Jun 2024 11:21:31 +0800 Subject: [PATCH] chore: apply flake8-comprehensions Ruff rules to improve collection comprehensions (#5652) Co-authored-by: -LAN- --- .../easy_ui_based_app/agent/manager.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_generator.py | 2 +- api/core/app/apps/agent_chat/app_runner.py | 2 +- api/core/app/apps/chat/app_generator.py | 2 +- api/core/entities/provider_configuration.py | 4 +-- api/core/indexing_runner.py | 2 +- .../model_providers/azure_openai/tts/tts.py | 2 +- .../model_providers/bedrock/llm/llm.py | 10 ++++---- .../bedrock/text_embedding/text_embedding.py | 6 ++--- .../model_providers/cohere/llm/llm.py | 8 +++--- .../model_providers/google/llm/llm.py | 5 +--- .../model_providers/moonshot/llm/llm.py | 4 +-- .../model_providers/nvidia/llm/llm.py | 8 +++--- .../model_providers/openai/llm/llm.py | 8 +++--- .../model_providers/openai/tts/tts.py | 2 +- .../model_providers/replicate/llm/llm.py | 15 ++++++----- .../model_providers/tongyi/tts/tts.py | 2 +- .../model_providers/vertex_ai/llm/llm.py | 5 +--- .../volcengine_maas/volc_sdk/base/auth.py | 2 +- .../volcengine_maas/volc_sdk/base/service.py | 2 +- .../model_providers/xinference/llm/llm.py | 2 +- .../model_providers/zhipuai/_common.py | 2 +- api/core/prompt/simple_prompt_transform.py | 6 ++--- api/core/provider_manager.py | 6 ++--- .../rag/datasource/keyword/jieba/jieba.py | 2 +- .../router/multi_dataset_react_route.py | 2 +- .../builtin/bing/tools/bing_web_search.py | 8 +++--- .../tools/provider/builtin/chart/tools/bar.py | 4 +-- .../provider/builtin/chart/tools/line.py | 4 +-- .../tools/provider/builtin/chart/tools/pie.py | 4 +-- .../builtin/gaode/tools/gaode_weather.py | 4 +-- .../github/tools/github_repositories.py | 4 +-- .../builtin/jina/tools/jina_reader.py | 6 ++--- .../builtin/jina/tools/jina_search.py | 2 +- 
.../builtin/searchapi/tools/google.py | 2 +- .../builtin/searchapi/tools/google_jobs.py | 4 +-- .../builtin/searchapi/tools/google_news.py | 2 +- .../builtin/searxng/tools/searxng_search.py | 2 +- .../builtin/websearch/tools/get_markdown.py | 2 +- .../builtin/websearch/tools/job_search.py | 4 +-- .../builtin/websearch/tools/news_search.py | 2 +- .../builtin/websearch/tools/scholar_search.py | 2 +- .../tools/provider/builtin_tool_provider.py | 2 +- api/core/tools/tool/api_tool.py | 4 +-- api/core/tools/tool/dataset_retriever_tool.py | 2 +- api/core/tools/tool_manager.py | 2 +- api/core/tools/utils/parser.py | 19 ++++++-------- api/core/tools/utils/web_reader_tool.py | 2 +- .../parameter_extractor_node.py | 4 +-- api/libs/oauth_data_source.py | 10 ++------ api/pyproject.toml | 25 ++++++++++++------- api/services/dataset_service.py | 2 +- api/services/recommended_app_service.py | 2 +- api/services/workflow/workflow_converter.py | 4 +-- .../workflow/nodes/test_llm.py | 4 +-- .../nodes/test_parameter_extractor.py | 2 +- .../prompt/test_simple_prompt_transform.py | 4 +-- 58 files changed, 123 insertions(+), 136 deletions(-) diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index f271aeed0c..dc65d4439b 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -40,7 +40,7 @@ class AgentConfigManager: 'provider_type': tool['provider_type'], 'provider_id': tool['provider_id'], 'tool_name': tool['tool_name'], - 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} + 'tool_parameters': tool.get('tool_parameters', {}) } agent_tools.append(AgentToolEntity(**agent_tool_properties)) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 2fcc325540..84723cb5c7 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ 
b/api/core/app/apps/advanced_chat/app_generator.py @@ -59,7 +59,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): inputs = args['inputs'] extras = { - "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False + "auto_generate_conversation_name": args.get('auto_generate_name', False) } # get conversation diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index a9beeb3a5c..df6a35918b 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -57,7 +57,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): inputs = args['inputs'] extras = { - "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + "auto_generate_conversation_name": args.get('auto_generate_name', True) } # get conversation diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 6aa615a48d..d1bbf679c5 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -203,7 +203,7 @@ class AgentChatAppRunner(AppRunner): llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) - if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): + if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 94e862cb87..5b896e2845 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -55,7 +55,7 @@ class 
ChatAppGenerator(MessageBasedAppGenerator): inputs = args['inputs'] extras = { - "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True + "auto_generate_conversation_name": args.get('auto_generate_name', True) } # get conversation diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 564dfd8973..f3cf54a58e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -66,8 +66,8 @@ class ProviderConfiguration(BaseModel): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any([len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations]) + if (any(len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations) and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index af4bed13ef..826edff608 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -397,7 +397,7 @@ class IndexingRunner: document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ - DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), + DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) } ) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 585b061afe..dcd154cff0 100644 --- 
a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -83,7 +83,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): max_workers = self._get_model_workers_limit(model, credentials) try: sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) - audio_bytes_list = list() + audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index dad5120d83..3756aa2fdc 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -175,8 +175,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock # - https://github.com/anthropics/anthropic-sdk-python client = AnthropicBedrock( - aws_access_key=credentials.get("aws_access_key_id", None), - aws_secret_key=credentials.get("aws_secret_access_key", None), + aws_access_key=credentials.get("aws_access_key_id"), + aws_secret_key=credentials.get("aws_secret_access_key"), aws_region=credentials["aws_region"], ) @@ -576,7 +576,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): """ Create payload for bedrock api call depending on model provider """ - payload = dict() + payload = {} model_prefix = model.split('.')[0] model_name = model.split('.')[1] @@ -648,8 +648,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): runtime_client = boto3.client( service_name='bedrock-runtime', config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id", None), - aws_secret_access_key=credentials.get("aws_secret_access_key", None) + aws_access_key_id=credentials.get("aws_access_key_id"), + 
aws_secret_access_key=credentials.get("aws_secret_access_key") ) model_prefix = model.split('.')[0] diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 35b1a8f389..993416cdc8 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -49,8 +49,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): bedrock_runtime = boto3.client( service_name='bedrock-runtime', config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id", None), - aws_secret_access_key=credentials.get("aws_secret_access_key", None) + aws_access_key_id=credentials.get("aws_access_key_id"), + aws_secret_access_key=credentials.get("aws_secret_access_key") ) embeddings = [] @@ -148,7 +148,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): """ Create payload for bedrock api call depending on model provider """ - payload = dict() + payload = {} if model_prefix == "amazon": payload['inputText'] = texts diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index f9fae5e8ca..89b04c0279 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -696,12 +696,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): en_US=model ), model_type=ModelType.LLM, - features=[feature for feature in base_model_schema_features], + features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - key: property for key, property in base_model_schema_model_properties.items() - }, - parameter_rules=[rule for rule in base_model_schema_parameters_rules], + model_properties=dict(base_model_schema_model_properties.items()), + 
parameter_rules=list(base_model_schema_parameters_rules), pricing=base_model_schema.pricing ) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index c934c54634..ebcd0af35b 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -277,10 +277,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): type='function', function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps({ - key: value - for key, value in part.function_call.args.items() - }) + arguments=json.dumps(dict(part.function_call.args.items())) ) ) ] diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 3e146559c8..ef301b0f6c 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -88,9 +88,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and set([ + if model_schema and { ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - ]).intersection(model_schema.features or []): + }.intersection(model_schema.features or []): credentials['function_calling_type'] = 'tool_call' def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index 047bbeda63..4b2dbf3d3a 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -100,10 +100,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): if api_key: headers["Authorization"] = f"Bearer {api_key}" - 
endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None + endpoint_url = credentials.get('endpoint_url') if endpoint_url and not endpoint_url.endswith('/'): endpoint_url += '/' - server_url = credentials['server_url'] if 'server_url' in credentials else None + server_url = credentials.get('server_url') # prepare the payload for a simple ping to the model data = { @@ -182,10 +182,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): if stream: headers['Accept'] = 'text/event-stream' - endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None + endpoint_url = credentials.get('endpoint_url') if endpoint_url and not endpoint_url.endswith('/'): endpoint_url += '/' - server_url = credentials['server_url'] if 'server_url' in credentials else None + server_url = credentials.get('server_url') data = { "model": model, diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 69afabadb3..aae2729bdf 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1073,12 +1073,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): en_US=model ), model_type=ModelType.LLM, - features=[feature for feature in base_model_schema_features], + features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - key: property for key, property in base_model_schema_model_properties.items() - }, - parameter_rules=[rule for rule in base_model_schema_parameters_rules], + model_properties=dict(base_model_schema_model_properties.items()), + parameter_rules=list(base_model_schema_parameters_rules), pricing=base_model_schema.pricing ) diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index f5e2ec4b7c..f83c57078a 100644 --- 
a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -80,7 +80,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): max_workers = self._get_model_workers_limit(model, credentials) try: sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) - audio_bytes_list = list() + audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index f4198dbfa7..31b81a829e 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -275,14 +275,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): @classmethod def _get_parameter_type(cls, param_type: str) -> str: - if param_type == 'integer': - return 'int' - elif param_type == 'number': - return 'float' - elif param_type == 'boolean': - return 'boolean' - elif param_type == 'string': - return 'string' + type_mapping = { + 'integer': 'int', + 'number': 'float', + 'boolean': 'boolean', + 'string': 'string' + } + return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index b00f7c7c93..7ef053479b 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -80,7 +80,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): max_workers = self._get_model_workers_limit(model, credentials) try: sentences = list(self._split_text_into_sentences(text=content_text, 
limit=word_limit)) - audio_bytes_list = list() + audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 804c3535fb..8901549110 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -579,10 +579,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): type='function', function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps({ - key: value - for key, value in part.function_call.args.items() - }) + arguments=json.dumps(dict(part.function_call.args.items())) ) ) ] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py index 48110f16d7..053432a089 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py @@ -102,7 +102,7 @@ class Signer: body_hash = Util.sha256(request.body) request.headers['X-Content-Sha256'] = body_hash - signed_headers = dict() + signed_headers = {} for key in request.headers: if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): signed_headers[key.lower()] = request.headers[key] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py index 03734ec54f..7271ae63fd 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py @@ -150,7 +150,7 @@ class Request: self.headers = 
OrderedDict() self.query = OrderedDict() self.body = '' - self.form = dict() + self.form = {} self.connection_timeout = 0 self.socket_timeout = 0 diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index cc3ce17975..637e9b32e6 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -147,7 +147,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): return self._get_num_tokens_by_gpt2(text) if is_completion_model: - return sum([tokens(str(message.content)) for message in messages]) + return sum(tokens(str(message.content)) for message in messages) tokens_per_message = 3 tokens_per_name = 1 diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 2574234abf..3412d8100f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -18,7 +18,7 @@ class _CommonZhipuaiAI: """ credentials_kwargs = { "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None, + credentials.get("zhipuai_api_key"), } return credentials_kwargs diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 9b0c96b8bf..452b270348 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -148,7 +148,7 @@ class SimplePromptTransform(PromptTransform): special_variable_keys.append('#histories#') if query_in_prompt: - prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}' + prompt += prompt_rules.get('query_prompt', '{{#query#}}') special_variable_keys.append('#query#') return { @@ -234,8 +234,8 @@ class SimplePromptTransform(PromptTransform): ) ), 
max_token_limit=rest_tokens, - human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + human_prefix=prompt_rules.get('human_prefix', 'Human'), + ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') ) # get prompt diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c9447a79df..c0b3746e18 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -417,7 +417,7 @@ class ProviderManager: model_load_balancing_enabled = cache_result == 'True' if not model_load_balancing_enabled: - return dict() + return {} provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ .filter( @@ -451,7 +451,7 @@ class ProviderManager: if not provider_records: provider_records = [] - provider_quota_to_provider_record_dict = dict() + provider_quota_to_provider_record_dict = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -661,7 +661,7 @@ class ProviderManager: provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) # Convert provider_records to dict - quota_type_to_provider_records_dict = dict() + quota_type_to_provider_records_dict = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 1a5d3d11df..7f7c46e2dd 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -197,7 +197,7 @@ class Jieba(BaseKeyword): chunk_indices_count[node_id] += 1 sorted_chunk_indices = sorted( - list(chunk_indices_count.keys()), + chunk_indices_count.keys(), key=lambda x: chunk_indices_count[x], reverse=True, ) diff --git 
a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 5de2a66e2d..92f24277c1 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -201,7 +201,7 @@ class ReactMultiDatasetRouter: tool_strings.append( f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") formatted_tools = "\n".join(tool_strings) - unique_tool_names = set(tool.name for tool in tools) + unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 761aecde94..f85a5ed472 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -105,15 +105,15 @@ class BingSearchTool(BuiltinTool): def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key', None) + key = credentials.get('subscription_key') if not key: raise Exception('subscription_key is required') - server_url = credentials.get('server_url', None) + server_url = credentials.get('server_url') if not server_url: server_url = self.url - query = tool_parameters.get('query', None) + query = tool_parameters.get('query') if not query: raise Exception('query is required') @@ -170,7 +170,7 @@ class BingSearchTool(BuiltinTool): if not server_url: server_url = self.url - query = tool_parameters.get('query', None) + query = tool_parameters.get('query') if not query: raise 
Exception('query is required') diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 7da2651099..749ec761c6 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -16,12 +16,12 @@ class BarChartTool(BuiltinTool): data = data.split(';') # if all data is int, convert to int - if all([i.isdigit() for i in data]): + if all(i.isdigit() for i in data): data = [int(i) for i in data] else: data = [float(i) for i in data] - axis = tool_parameters.get('x_axis', None) or None + axis = tool_parameters.get('x_axis') or None if axis: axis = axis.split(';') if len(axis) != len(data): diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 9bc36be857..608bd6623c 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -17,14 +17,14 @@ class LinearChartTool(BuiltinTool): return self.create_text_message('Please input data') data = data.split(';') - axis = tool_parameters.get('x_axis', None) or None + axis = tool_parameters.get('x_axis') or None if axis: axis = axis.split(';') if len(axis) != len(data): axis = None # if all data is int, convert to int - if all([i.isdigit() for i in data]): + if all(i.isdigit() for i in data): data = [int(i) for i in data] else: data = [float(i) for i in data] diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index cd5e9b5329..4c551229e9 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -16,10 +16,10 @@ class PieChartTool(BuiltinTool): if not data: return self.create_text_message('Please input data') data = data.split(';') - categories = tool_parameters.get('categories', None) or None + categories = tool_parameters.get('categories') or None # 
if all data is int, convert to int - if all([i.isdigit() for i in data]): + if all(i.isdigit() for i in data): data = [int(i) for i in data] else: data = [float(i) for i in data] diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index 028da946d1..efd11cedce 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -37,10 +37,10 @@ class GaodeRepositoriesTool(BuiltinTool): apikey=self.runtime.credentials.get('api_key'))) weatherInfo_data = weatherInfo_response.json() if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': - contents = list() + contents = [] if len(weatherInfo_data.get('forecasts')) > 0: for item in weatherInfo_data['forecasts'][0]['casts']: - content = dict() + content = {} content['date'] = item.get('date') content['week'] = item.get('week') content['dayweather'] = item.get('dayweather') diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index 8a006f885f..a2f1e07fd4 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -39,10 +39,10 @@ class GihubRepositoriesTool(BuiltinTool): f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") response_data = response.json() if response.status_code == 200 and isinstance(response_data.get('items'), list): - contents = list() + contents = [] if len(response_data.get('items')) > 0: for item in response_data.get('items'): - content = dict() + content = {} updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") content['owner'] = item['owner']['login'] content['name'] = item['name'] diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py 
b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 0d0eaef25b..ac06688c18 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -26,11 +26,11 @@ class JinaReaderTool(BuiltinTool): if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') - target_selector = tool_parameters.get('target_selector', None) + target_selector = tool_parameters.get('target_selector') if target_selector is not None and target_selector != '': headers['X-Target-Selector'] = target_selector - wait_for_selector = tool_parameters.get('wait_for_selector', None) + wait_for_selector = tool_parameters.get('wait_for_selector') if wait_for_selector is not None and wait_for_selector != '': headers['X-Wait-For-Selector'] = wait_for_selector @@ -43,7 +43,7 @@ class JinaReaderTool(BuiltinTool): if tool_parameters.get('gather_all_images_at_the_end', False): headers['X-With-Images-Summary'] = 'true' - proxy_server = tool_parameters.get('proxy_server', None) + proxy_server = tool_parameters.get('proxy_server') if proxy_server is not None and proxy_server != '': headers['X-Proxy-Url'] = proxy_server diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index 3eda2c5a22..e6bc08147f 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -33,7 +33,7 @@ class JinaSearchTool(BuiltinTool): if tool_parameters.get('gather_all_images_at_the_end', False): headers['X-With-Images-Summary'] = 'true' - proxy_server = tool_parameters.get('proxy_server', None) + proxy_server = tool_parameters.get('proxy_server') if proxy_server is not None and proxy_server != '': headers['X-Proxy-Url'] = proxy_server diff --git 
a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index d019fe7134..dd780aeadc 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -94,7 +94,7 @@ class GoogleTool(BuiltinTool): google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials['searchapi_api_key'] result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 1b8cfa7e30..81c67c51a9 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -72,11 +72,11 @@ class GoogleJobsTool(BuiltinTool): """ query = tool_parameters['query'] result_type = tool_parameters['result_type'] - is_remote = tool_parameters.get("is_remote", None) + is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") ltype = 1 if is_remote else None diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index d592dc25aa..5d2657dddd 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -82,7 +82,7 @@ class GoogleNewsTool(BuiltinTool): google_domain = 
tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials['searchapi_api_key'] result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index 3e46916b9b..5d12553629 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -107,7 +107,7 @@ class SearXNGSearchTool(BuiltinTool): if not host: raise Exception('SearXNG api is required') - query = tool_parameters.get('query', None) + query = tool_parameters.get('query') if not query: return self.create_text_message('Please input query') diff --git a/api/core/tools/provider/builtin/websearch/tools/get_markdown.py b/api/core/tools/provider/builtin/websearch/tools/get_markdown.py index 92d7d1addc..043879deea 100644 --- a/api/core/tools/provider/builtin/websearch/tools/get_markdown.py +++ b/api/core/tools/provider/builtin/websearch/tools/get_markdown.py @@ -43,7 +43,7 @@ class GetMarkdownTool(BuiltinTool): Invoke the SerplyApi tool. 
""" url = tool_parameters["url"] - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials["serply_api_key"] result = SerplyApi(api_key).run(url, location=location) diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py index 347b4eb4c4..9128305922 100644 --- a/api/core/tools/provider/builtin/websearch/tools/job_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py @@ -55,7 +55,7 @@ class SerplyApi: f"Employer: {job['employer']}", f"Location: {job['location']}", f"Link: {job['link']}", - f"""Highest: {", ".join([h for h in job["highlights"]])}""", + f"""Highest: {", ".join(list(job["highlights"]))}""", "---", ]) ) @@ -78,7 +78,7 @@ class JobSearchTool(BuiltinTool): query = tool_parameters["query"] gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials["serply_api_key"] result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py index 886ea47765..e9c0744f05 100644 --- a/api/core/tools/provider/builtin/websearch/tools/news_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py @@ -80,7 +80,7 @@ class NewsSearchTool(BuiltinTool): query = tool_parameters["query"] gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials["serply_api_key"] result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py 
b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py index 19df455231..0030a03c06 100644 --- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py @@ -83,7 +83,7 @@ class ScholarSearchTool(BuiltinTool): query = tool_parameters["query"] gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") - location = tool_parameters.get("location", None) + location = tool_parameters.get("location") api_key = self.runtime.credentials["serply_api_key"] result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index d076cb384f..47e33b70c9 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -38,7 +38,7 @@ class BuiltinToolProviderController(ToolProviderController): super().__init__(**{ 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None, + 'credentials_schema': provider_yaml.get('credentials_for_provider'), }) def _get_builtin_tools(self) -> list[Tool]: diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 3464bacced..0448a5df0c 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -159,8 +159,8 @@ class ApiTool(Tool): for content_type in self.api_bundle.openapi['requestBody']['content']: headers['Content-Type'] = content_type body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] - required = body_schema['required'] if 'required' in body_schema else [] - properties = body_schema['properties'] if 'properties' in body_schema else {} + required = body_schema.get('required', []) + properties = body_schema.get('properties', {}) for name, property in properties.items(): 
if name in parameters: # convert type diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index e52981b2d1..1170e1b7a5 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -90,7 +90,7 @@ class DatasetRetrieverTool(Tool): """ invoke dataset retriever tool """ - query = tool_parameters.get('query', None) + query = tool_parameters.get('query') if not query: return self.create_text_message(text='please input query') diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index aa184176a1..9fcadbd391 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -209,7 +209,7 @@ class ToolManager: if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: # check if tool_parameter_config in options - options = list(map(lambda x: x.value, parameter_rule.options)) + options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 40ae6c66d5..f711f7c9f3 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -21,10 +21,7 @@ class ApiBasedToolSchemaParser: extra_info = extra_info if extra_info is not None else {} # set description to extra_info - if 'description' in openapi['info']: - extra_info['description'] = openapi['info']['description'] - else: - extra_info['description'] = '' + extra_info['description'] = openapi['info'].get('description', '') if len(openapi['servers']) == 0: raise ToolProviderNotFoundError('No server found in the openapi yaml.') @@ -95,8 +92,8 @@ class ApiBasedToolSchemaParser: # parse body parameters if 'schema' in interface['operation']['requestBody']['content'][content_type]: body_schema = 
interface['operation']['requestBody']['content'][content_type]['schema'] - required = body_schema['required'] if 'required' in body_schema else [] - properties = body_schema['properties'] if 'properties' in body_schema else {} + required = body_schema.get('required', []) + properties = body_schema.get('properties', {}) for name, property in properties.items(): tool = ToolParameter( name=name, @@ -105,14 +102,14 @@ class ApiBasedToolSchemaParser: zh_Hans=name ), human_description=I18nObject( - en_US=property['description'] if 'description' in property else '', - zh_Hans=property['description'] if 'description' in property else '' + en_US=property.get('description', ''), + zh_Hans=property.get('description', '') ), type=ToolParameter.ToolParameterType.STRING, required=name in required, form=ToolParameter.ToolParameterForm.LLM, - llm_description=property['description'] if 'description' in property else '', - default=property['default'] if 'default' in property else None, + llm_description=property.get('description', ''), + default=property.get('default'), ) # check if there is a type @@ -149,7 +146,7 @@ class ApiBasedToolSchemaParser: server_url=server_url + interface['path'], method=interface['method'], summary=interface['operation']['description'] if 'description' in interface['operation'] else - interface['operation']['summary'] if 'summary' in interface['operation'] else None, + interface['operation'].get('summary'), operation_id=interface['operation']['operationId'], parameters=parameters, author='', diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 4c69c6eddc..1e7eb129a7 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -283,7 +283,7 @@ def strip_control_characters(text): # [Cn]: Other, Not Assigned # [Co]: Other, Private Use # [Cs]: Other, Surrogate - control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) + control_chars = {'Cc', 'Cf', 'Cn', 'Co', 
'Cs'} retained_chars = ['\t', '\n', '\r', '\f'] # Remove non-printing control characters diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index bb0ccb5fc3..386fa410aa 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -93,7 +93,7 @@ class ParameterExtractorNode(LLMNode): # fetch memory memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) - if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \ + if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ and node_data.reasoning_mode == 'function_call': # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( @@ -644,7 +644,7 @@ class ParameterExtractorNode(LLMNode): if not model_schema: raise ValueError("Model schema not found") - if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]): + if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 3f2889adbe..a5c7814a54 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -246,10 +246,7 @@ class NotionOAuth(OAuthDataSource): } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - if 'results' in response_json: - results = response_json['results'] - else: - results = [] + results = response_json.get('results', []) return results def 
notion_block_parent_page_id(self, access_token: str, block_id: str): @@ -293,8 +290,5 @@ class NotionOAuth(OAuthDataSource): } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - if 'results' in response_json: - results = response_json['results'] - else: - results = [] + results = response_json.get('results', []) return results diff --git a/api/pyproject.toml b/api/pyproject.toml index 4174749f24..7b15406570 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -14,9 +14,11 @@ line-length = 120 preview = true select = [ "B", # flake8-bugbear rules + "C4", # flake8-comprehensions "F", # pyflakes rules "I", # isort rules - "UP", # pyupgrade rules + "UP", # pyupgrade rules + "B035", # static-key-dict-comprehension "E101", # mixed-spaces-and-tabs "E111", # indentation-with-invalid-multiple "E112", # no-indented-block @@ -28,8 +30,13 @@ select = [ "RUF100", # unused-noqa "RUF101", # redirected-noqa "S506", # unsafe-yaml-load + "SIM116", # if-else-block-instead-of-dict-lookup + "SIM401", # if-else-block-instead-of-dict-get + "SIM910", # dict-get-with-none-default "W191", # tab-indentation "W605", # invalid-escape-sequence + "F601", # multi-value-repeated-key-literal + "F602", # multi-value-repeated-key-variable ] ignore = [ "F403", # undefined-local-with-import-star @@ -82,8 +89,8 @@ HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b" HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" MOCK_SWITCH = "true" CODE_MAX_STRING_LENGTH = "80000" -CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194" -CODE_EXECUTION_API_KEY="dify-sandbox" +CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" +CODE_EXECUTION_API_KEY = "dify-sandbox" FIRECRAWL_API_KEY = "fc-" [tool.poetry] @@ -114,11 +121,11 @@ cachetools = "~5.3.0" weaviate-client = "~3.21.0" mailchimp-transactional = "~1.0.50" scikit-learn = "1.2.2" -sentry-sdk = {version = "~1.39.2", extras = ["flask"]} +sentry-sdk = { version = "~1.39.2", extras = ["flask"] } sympy = "1.12" jieba = 
"0.42.1" celery = "~5.3.6" -redis = {version = "~5.0.3", extras = ["hiredis"]} +redis = { version = "~5.0.3", extras = ["hiredis"] } chardet = "~5.1.0" python-docx = "~1.1.0" pypdfium2 = "~4.17.0" @@ -138,7 +145,7 @@ googleapis-common-protos = "1.63.0" google-cloud-storage = "2.16.0" replicate = "~0.22.0" websocket-client = "~1.7.0" -dashscope = {version = "~1.17.0", extras = ["tokenizer"]} +dashscope = { version = "~1.17.0", extras = ["tokenizer"] } huggingface-hub = "~0.16.4" transformers = "~4.35.0" tokenizers = "~0.15.0" @@ -152,10 +159,10 @@ qdrant-client = "1.7.3" cohere = "~5.2.4" pyyaml = "~6.0.1" numpy = "~1.26.4" -unstructured = {version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"]} +unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } bs4 = "~0.0.1" markdown = "~3.5.1" -httpx = {version = "~0.27.0", extras = ["socks"]} +httpx = { version = "~0.27.0", extras = ["socks"] } matplotlib = "~3.8.2" yfinance = "~0.2.40" pydub = "~0.25.1" @@ -180,7 +187,7 @@ pgvector = "0.2.5" pymysql = "1.1.1" tidb-vector = "0.0.9" google-cloud-aiplatform = "1.49.0" -vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]} +vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } kaleido = "0.2.1" tencentcloud-sdk-python-hunyuan = "~3.0.1158" tcvectordb = "1.3.2" diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e8446da44c..38ef874af3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -696,7 +696,7 @@ class DocumentService: elif document_data["data_source"]["type"] == "notion_import": notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] exist_page_ids = [] - exist_document = dict() + exist_document = {} documents = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, diff --git a/api/services/recommended_app_service.py 
b/api/services/recommended_app_service.py index 6a155922b4..2c2c0efc7a 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -95,7 +95,7 @@ class RecommendedAppService: categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))} + return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index d76cd4c7ff..010d53389a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -514,8 +514,8 @@ class WorkflowConverter: prompt_rules = prompt_template_config['prompt_rules'] role_prefix = { - "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + "user": prompt_rules.get('human_prefix', 'Human'), + "assistant": prompt_rules.get('assistant_prefix', 'Assistant') } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index a150be3c00..d7a6c1224f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -112,7 +112,7 @@ def test_execute_llm(setup_openai_mock): # Mock db.session.close() db.session.close = MagicMock() - node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node result = node.run(pool) @@ -229,7 +229,7 @@ def 
test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): # Mock db.session.close() db.session.close = MagicMock() - node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) + node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) # execute node result = node.run(pool) diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 056c78441d..3379e8338d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -59,7 +59,7 @@ def get_mocked_fetch_model_config( provider_model_bundle=provider_model_bundle ) - return MagicMock(return_value=tuple([model_instance, model_config])) + return MagicMock(return_value=(model_instance, model_config)) @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) def test_function_calling_parameter_extractor(setup_openai_mock): diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 7e32ecbbdb..6d6363610b 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages(): prompt_rules = prompt_template['prompt_rules'] full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( max_token_limit=2000, - human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', - ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") )} real_prompt = 
prompt_template['prompt_template'].format(full_inputs)