From 40fb4d16ef0f5e8dd984c205b9941df68a445b0f Mon Sep 17 00:00:00 2001
From: Bowen Liang
Date: Thu, 12 Sep 2024 15:50:49 +0800
Subject: [PATCH] chore: refurbish Python code by applying refurb linter rules
 (#8296)

---
 api/controllers/console/admin.py              | 24 +++++++------------
 api/controllers/console/app/audio.py          |  8 ++-----
 api/controllers/console/auth/oauth.py         |  2 +-
 api/controllers/console/datasets/datasets.py  |  7 +-----
 api/controllers/console/explore/audio.py      |  8 ++-----
 .../console/workspace/tool_providers.py       |  2 +-
 api/controllers/service_api/app/audio.py      |  8 ++-----
 api/controllers/web/audio.py                  |  8 ++-----
 api/core/agent/cot_agent_runner.py            |  2 +-
 api/core/agent/fc_agent_runner.py             |  2 +-
 api/core/app/apps/base_app_runner.py          |  8 +++----
 api/core/extension/extensible.py              |  4 ++--
 api/core/memory/token_buffer_memory.py        |  2 +-
 .../__base/large_language_model.py            |  2 +-
 .../model_providers/anthropic/llm/llm.py      |  2 +-
 .../azure_ai_studio/llm/llm.py                |  2 +-
 .../model_providers/azure_openai/llm/llm.py   | 10 ++++----
 .../model_providers/azure_openai/tts/tts.py   |  2 +-
 .../model_providers/bedrock/llm/llm.py        |  8 +++----
 .../model_providers/chatglm/llm/llm.py        |  4 ++--
 .../model_providers/localai/llm/llm.py        |  6 ++---
 .../model_providers/minimax/llm/llm.py        |  4 ++--
 .../ollama/text_embedding/text_embedding.py   |  2 +-
 .../model_providers/openai/llm/llm.py         |  8 +++----
 .../model_providers/openai/tts/tts.py         |  2 +-
 .../openai_api_compatible/llm/llm.py          |  4 ++--
 .../model_providers/openllm/llm/llm.py        |  4 ++--
 .../openllm/llm/openllm_generate.py           |  2 +-
 .../model_providers/replicate/llm/llm.py      |  2 +-
 .../sagemaker/rerank/rerank.py                |  3 ++-
 .../model_providers/sagemaker/tts/tts.py      |  2 +-
 .../model_providers/spark/llm/llm.py          |  2 +-
 .../tencent/speech2text/flash_recognizer.py   |  3 ++-
 .../model_providers/tongyi/llm/llm.py         |  4 ++--
 .../model_providers/upstage/llm/llm.py        |  6 ++---
 .../model_providers/vertex_ai/llm/llm.py      |  4 ++--
 .../legacy/volc_sdk/base/auth.py              |  3 ++-
 .../legacy/volc_sdk/base/util.py              |  3 ++-
 .../volcengine_maas/llm/llm.py                |  8 +++----
 .../model_providers/wenxin/llm/llm.py         |  6 ++---
 .../wenxin/text_embedding/text_embedding.py   |  2 +-
 .../model_providers/xinference/llm/llm.py     |  6 ++---
 .../model_providers/xinference/tts/tts.py     |  2 +-
 .../model_providers/zhipuai/llm/llm.py        |  4 ++--
 .../zhipuai/zhipuai_sdk/core/_http_client.py  |  4 +++-
 api/core/ops/langfuse_trace/langfuse_trace.py | 22 ++++++++---------
 .../ops/langsmith_trace/langsmith_trace.py    | 10 ++++----
 api/core/ops/ops_trace_manager.py             | 15 ++++++------
 api/core/prompt/simple_prompt_transform.py    |  6 ++---
 .../rag/datasource/keyword/keyword_base.py    |  2 +-
 .../vdb/analyticdb/analyticdb_vector.py       |  4 ++--
 .../datasource/vdb/chroma/chroma_vector.py    |  2 +-
 .../vdb/elasticsearch/elasticsearch_vector.py |  6 ++---
 .../datasource/vdb/milvus/milvus_vector.py    |  2 +-
 .../datasource/vdb/myscale/myscale_vector.py  |  2 +-
 .../vdb/opensearch/opensearch_vector.py       |  2 +-
 .../rag/datasource/vdb/oracle/oraclevector.py |  4 ++--
 .../datasource/vdb/pgvecto_rs/pgvecto_rs.py   |  2 +-
 .../rag/datasource/vdb/pgvector/pgvector.py   |  2 +-
 .../datasource/vdb/qdrant/qdrant_vector.py    |  2 +-
 .../rag/datasource/vdb/relyt/relyt_vector.py  |  2 +-
 .../datasource/vdb/tencent/tencent_vector.py  |  2 +-
 .../datasource/vdb/tidb_vector/tidb_vector.py |  2 +-
 api/core/rag/datasource/vdb/vector_base.py    |  2 +-
 api/core/rag/datasource/vdb/vector_factory.py |  2 +-
 .../vdb/weaviate/weaviate_vector.py           |  2 +-
 api/core/rag/extractor/blob/blob.py           |  8 +++----
 api/core/rag/extractor/extract_processor.py   |  7 +++---
 api/core/rag/extractor/helpers.py             |  4 ++--
 api/core/rag/extractor/markdown_extractor.py  |  7 +++---
 api/core/rag/extractor/text_extractor.py      |  7 +++---
 api/core/rag/extractor/word_extractor.py      |  2 +-
 api/core/rag/retrieval/dataset_retrieval.py   | 10 ++++----
 api/core/tools/provider/api_tool_provider.py  |  2 +-
 .../aws/tools/sagemaker_text_rerank.py        |  3 ++-
 .../builtin/hap/tools/get_worksheet_fields.py |  2 +-
 .../hap/tools/get_worksheet_pivot_data.py     |  2 +-
 .../hap/tools/list_worksheet_records.py       |  2 +-
 .../searchapi/tools/youtube_transcripts.py    |  2 +-
 .../dataset_multi_retriever_tool.py           |  6 ++---
 .../dataset_retriever_tool.py                 |  6 ++---
 api/core/tools/utils/web_reader_tool.py       |  6 ++---
 api/core/tools/utils/yaml_utils.py            |  2 +-
 .../workflow/graph_engine/entities/graph.py   |  5 ++--
 api/core/workflow/nodes/code/code_node.py     |  8 +++---
 .../workflow/nodes/if_else/if_else_node.py    |  2 +-
 api/core/workflow/nodes/llm/llm_node.py       |  2 +-
 .../question_classifier_node.py               |  2 +-
 ...aset_join_when_app_model_config_updated.py |  3 +--
 ...oin_when_app_published_workflow_updated.py |  3 +--
 api/extensions/storage/local_storage.py       |  8 +++---
 api/models/dataset.py                         |  4 ++--
 api/models/model.py                           | 10 ++++----
 api/pyproject.toml                            |  3 +++
 api/services/account_service.py               |  2 +-
 api/services/app_dsl_service.py               | 12 ++++------
 api/services/dataset_service.py               |  8 ++-----
 api/services/hit_testing_service.py           |  6 ++---
 api/services/model_provider_service.py        |  6 ++---
 api/services/recommended_app_service.py       |  8 +++---
 api/services/tag_service.py                   |  2 +-
 .../tools/builtin_tools_manage_service.py     |  4 ++--
 api/services/website_service.py               |  4 ++--
 api/services/workflow/workflow_converter.py   |  8 +++---
 .../model_runtime/__mock/huggingface_tei.py   |  2 +-
 105 files changed, 220 insertions(+), 276 deletions(-)

diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index a4ceec2662..f78ea9b288 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
         site = app.site

         if not site:
-            desc = args["desc"] if args["desc"] else ""
-            copy_right = args["copyright"] if args["copyright"] else ""
-            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
-            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
+            desc = args["desc"] or ""
+            copy_right = args["copyright"] or ""
+            privacy_policy = args["privacy_policy"] or ""
+            custom_disclaimer = args["custom_disclaimer"] or ""
         else:
-            desc = site.description if site.description else args["desc"] if args["desc"] else ""
-            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
-            privacy_policy = (
-                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
-            )
-            custom_disclaimer = (
-                site.custom_disclaimer
-                if site.custom_disclaimer
-                else args["custom_disclaimer"]
-                if args["custom_disclaimer"]
-                else ""
-            )
+            desc = site.description or args["desc"] or ""
+            copy_right = site.copyright or args["copyright"] or ""
+            privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
+            custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""

         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
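Most hunks in this patch apply the same refurb simplification: a conditional expression of the form `x if x else y` becomes the short-circuiting `x or y`. A minimal sketch of the equivalence, with illustrative values:

    args = {"desc": ""}  # hypothetical input
    before = args["desc"] if args["desc"] else "default"
    after = args["desc"] or "default"
    assert before == after == "default"
    # Note: `or` falls through on any falsy value ("", 0, None, []);
    # use `x if x is not None else y` when only None should trigger the fallback.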
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 437a6a7b38..7332758e83 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index ae1b49f3ec..1df0f5de9d 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):

     if not account:
         # Create account
-        account_name = user_info.name if user_info.name else "Dify"
+        account_name = user_info.name or "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )

diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 6ccacc78ee..08603bfc21 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {
-            "api_base_url": (
-                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
-            )
-            + "/v1"
-        }
+        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}


 class DatasetRetrievalSettingApi(Resource):

diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 71cb060ecc..2eb7e04490 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index c41a898fdc..d2a17b133b 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):

         return ApiToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
-            args["provider_name"] if args["provider_name"] else "",
+            args["provider_name"] or "",
             args["tool_name"],
             args["credentials"],
             args["parameters"],

diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 85aab047a7..8d8ca8d78c 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -84,14 +84,10 @@ class TextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(

diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index d062d2893b..49c467dbe1 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -83,14 +83,10 @@ class TextApi(WebApiResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None


diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 29b428a7c3..ebe04bf260 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     model=model_instance.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                     system_fingerprint="",
                 )
             ),

diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index 27cf561e3d..13164e0bfc 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     model=model_instance.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                     system_fingerprint="",
                 )
             ),

diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index aadb43ad39..bd2be18bfd 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -161,7 +161,7 @@ class AppRunner:
                 app_mode=AppMode.value_of(app_record.mode),
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
-                query=query if query else "",
+                query=query or "",
                 files=files,
                 context=context,
                 memory=memory,
@@ -189,7 +189,7 @@ class AppRunner:
             prompt_messages = prompt_transform.get_prompt(
                 prompt_template=prompt_template,
                 inputs=inputs,
-                query=query if query else "",
+                query=query or "",
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -238,7 +238,7 @@ class AppRunner:
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=text),
-                    usage=usage if usage else LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                 ),
             ),
             PublishFrom.APPLICATION_MANAGER,
@@ -351,7 +351,7 @@
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             message_id=message_id,
             trace_manager=app_generate_entity.trace_manager,
         )

diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index f1a49c4921..97dbaf2026 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -3,6 +3,7 @@ import importlib.util
 import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Optional

 from pydantic import BaseModel
@@ -63,8 +64,7 @@ class Extensible:
                 builtin_file_path = os.path.join(subdir_path, "__builtin__")
                 if os.path.exists(builtin_file_path):
-                    with open(builtin_file_path, encoding="utf-8") as f:
-                        position = int(f.read().strip())
+                    position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
                     position_map[extension_name] = position

                 if (extension_name + ".py") not in file_names:

diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 54b1d8212b..a14d237a12 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -39,7 +39,7 @@ class TokenBufferMemory:
         )

         if message_limit and message_limit > 0:
-            message_limit = message_limit if message_limit <= 500 else 500
+            message_limit = min(message_limit, 500)
         else:
             message_limit = 500
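Clamping with min() is another pattern refurb rewrites; the conditional and the builtin are interchangeable here. A quick check with an illustrative value:

    message_limit = 800  # hypothetical value
    assert (message_limit if message_limit <= 500 else 500) == min(message_limit, 500) == 500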
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index e8789ec7df..ba88cc1f38 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -449,7 +449,7 @@ if you are not sure about the structure.
                     model=real_model,
                     prompt_messages=prompt_messages,
                     message=prompt_message,
-                    usage=usage if usage else LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                     system_fingerprint=system_fingerprint,
                 ),
                 credentials=credentials,

diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index 0cb66842e7..ff741e0240 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                     ),
                 )
             elif isinstance(chunk, ContentBlockDeltaEvent):
-                chunk_text = chunk.delta.text if chunk.delta.text else ""
+                chunk_text = chunk.delta.text or ""
                 full_assistant_content += chunk_text

                 # transform assistant message to prompt message

diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py
index 42eae6c1e5..516ef8b295 100644
--- a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py
@@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
                     model=real_model,
                     prompt_messages=prompt_messages,
                     message=prompt_message,
-                    usage=usage if usage else LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                     system_fingerprint=system_fingerprint,
                 ),
                 credentials=credentials,

diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index 3b9fb52e24..f0033ea051 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)

             full_text += text
@@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             real_model = chunk.model
             system_fingerprint = chunk.system_fingerprint
-            completion += delta.delta.content if delta.delta.content else ""
+            completion += delta.delta.content or ""

             yield LLMResultChunk(
                 model=real_model,

diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
index bbad726467..8db044b24d 100644
--- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
@@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                     )
                     for i in range(len(sentences))
                 ]
-                for index, future in enumerate(futures):
+                for future in futures:
                     yield from future.result().__enter__().iter_bytes(1024)

             else:

diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 239ae52b4c..953ac0741f 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             elif "contentBlockDelta" in chunk:
                 delta = chunk["contentBlockDelta"]["delta"]
                 if "text" in delta:
-                    chunk_text = delta["text"] if delta["text"] else ""
+                    chunk_text = delta["text"] or ""
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else "",
+                        content=chunk_text or "",
                     )
                     index = chunk["contentBlockDelta"]["contentBlockIndex"]
                     yield LLMResultChunk(
@@ -751,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         elif model_prefix == "cohere":
             output = response_body.get("generations")[0].get("text")
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
+            completion_tokens = self.get_num_tokens(model, credentials, output or "")

         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -828,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=content_delta if content_delta else "",
+                content=content_delta or "",
             )
             index += 1

diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
index 114acc1ec3..b3eeb48e22 100644
--- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
@@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]

-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )

             if delta.finish_reason is not None:

diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py
index 94c03efe7b..e7295355f6 100644
--- a/api/core/model_runtime/model_providers/localai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/localai/llm/llm.py
@@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage
@@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]

-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )

             if delta.finish_reason is not None:

diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py
index 76ed704a75..4250c40cfb 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/llm.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py
@@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                         usage=usage,
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )
             elif message.function_call:
@@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=0,
                         message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                        finish_reason=message.stop_reason if message.stop_reason else None,
+                        finish_reason=message.stop_reason or None,
                     ),
                 )

diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
index 2cfb79b241..b4c61d8a6d 100644
--- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
@@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         inputs = []
         used_tokens = 0

-        for i, text in enumerate(texts):
+        for text in texts:
             # Here token count is only an approximation based on the GPT2 tokenizer
             num_tokens = self._get_num_tokens_by_gpt2(text)

diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 578687b5d3..c4b381df41 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 continue

             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)

             full_text += text
@@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if has_finish_reason:
                 final_chunk = LLMResultChunk(

diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py
index bfb443698c..b50b43199f 100644
--- a/api/core/model_runtime/model_providers/openai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/openai/tts/tts.py
@@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     )
                     for i in range(len(sentences))
                 ]
-                for index, future in enumerate(futures):
+                for future in futures:
                     yield from future.result().__enter__().iter_bytes(1024)

             else:

diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 41ca163a92..5a8a754f72 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -179,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         features = []

         function_calling_type = credentials.get("function_calling_type", "no_call")
-        if function_calling_type in ["function_call"]:
+        if function_calling_type == "function_call":
             features.append(ModelFeature.TOOL_CALL)
-        elif function_calling_type in ["tool_call"]:
+        elif function_calling_type == "tool_call":
             features.append(ModelFeature.MULTI_TOOL_CALL)

         stream_function_calling = credentials.get("stream_function_calling", "supported")

diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py
index b560afca39..34b4de7962 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py
@@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                     usage=usage,
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
         else:
@@ -189,7 +189,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )

diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index e754479ec0..351dcced15 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -106,7 +106,7 @@ class OpenLLMGenerate:
             timeout = 120

         data = {
-            "stop": stop if stop else [],
+            "stop": stop or [],
             "prompt": "\n".join([message.content for message in prompt_messages]),
             "llm_config": default_llm_config,
         }

diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py
index 87c8bc4a91..e77372cae2 100644
--- a/api/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):

             index += 1

-            assistant_prompt_message = AssistantPromptMessage(content=output if output else "")
+            assistant_prompt_message = AssistantPromptMessage(content=output or "")

             if index < prediction_output_length:
                 yield LLMResultChunk(

diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
index 7e7614055c..959dff6a21 100644
--- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import operator
 from typing import Any, Optional

 import boto3
@@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
             for idx in range(len(scores)):
                 candidate_docs.append({"content": docs[idx], "score": scores[idx]})

-            sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
+            candidate_docs.sort(key=operator.itemgetter("score"), reverse=True)

             line = 3
             rerank_documents = []

diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
index 3dd5f8f64c..a22bd6dd6e 100644
--- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
+++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py
@@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
                 for payload in payloads
             ]

-            for index, future in enumerate(futures):
+            for future in futures:
                 resp = future.result()
                 audio_bytes = requests.get(resp.get("s3_presign_url")).content
                 for i in range(0, len(audio_bytes), 1024):

diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py
index 0c42acf5aa..57193dc031 100644
--- a/api/core/model_runtime/model_providers/spark/llm/llm.py
+++ b/api/core/model_runtime/model_providers/spark/llm/llm.py
@@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
                 delta = content

             assistant_prompt_message = AssistantPromptMessage(
-                content=delta if delta else "",
+                content=delta or "",
             )

             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)

diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py
index 9fd4a45f45..c3c21793e8 100644
--- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py
+++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py
@@ -1,6 +1,7 @@
 import base64
 import hashlib
 import hmac
+import operator
 import time

 import requests
@@ -127,7 +128,7 @@ class FlashRecognizer:
         return s

     def _build_req_with_signature(self, secret_key, params, header):
-        query = sorted(params.items(), key=lambda d: d[0])
+        query = sorted(params.items(), key=operator.itemgetter(0))
         signstr = self._format_sign_string(query)
         signature = self._sign(signstr, secret_key)
         header["Authorization"] = signature

diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
index db0b2deaa5..ca708d931c 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -4,6 +4,7 @@ import tempfile
 import uuid
 from collections.abc import Generator
 from http import HTTPStatus
+from pathlib import Path
 from typing import Optional, Union, cast

 from dashscope import Generation, MultiModalConversation, get_tokenizer
@@ -454,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):

                 file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")

-                with open(file_path, "wb") as image_file:
-                    image_file.write(base64.b64decode(encoded_string))
+                Path(file_path).write_bytes(base64.b64decode(encoded_string))

                 return f"file://{file_path}"
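The tongyi and extensible.py hunks swap open()/read()/write() boilerplate for pathlib one-liners. A self-contained sketch of the idiom (paths are illustrative, not from the patch):

    from pathlib import Path

    p = Path("/tmp/position.txt")              # hypothetical file
    p.write_text("42\n", encoding="utf-8")     # replaces open(..., "w") + write()
    position = int(p.read_text(encoding="utf-8").strip())
    Path("/tmp/image.png").write_bytes(b"\x89PNG")  # replaces open(..., "wb") + write()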
diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py
index 9646e209b2..74524e81e2 100644
--- a/api/core/model_runtime/model_providers/upstage/llm/llm.py
+++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py
@@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if has_finish_reason:
                 final_chunk = LLMResultChunk(

diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
index 1b9931d2c3..dad5002f35 100644
--- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -231,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                     ),
                 )
             elif isinstance(chunk, ContentBlockDeltaEvent):
-                chunk_text = chunk.delta.text if chunk.delta.text else ""
+                chunk_text = chunk.delta.text or ""
                 full_assistant_content += chunk_text
                 assistant_prompt_message = AssistantPromptMessage(
-                    content=chunk_text if chunk_text else "",
+                    content=chunk_text or "",
                 )
                 index = chunk.index
                 yield LLMResultChunk(

diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
index 7435720252..97c77de8d3 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py
@@ -1,5 +1,6 @@
 # coding : utf-8
 import datetime
+from itertools import starmap

 import pytz

@@ -48,7 +49,7 @@ class SignResult:
         self.authorization = ""

     def __str__(self):
-        return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()])
+        return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))


 class Credentials:

diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py
index 44f9959965..178d63714e 100644
--- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py
+++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py
@@ -1,5 +1,6 @@
 import hashlib
 import hmac
+import operator
 from functools import reduce
 from urllib.parse import quote

@@ -40,4 +41,4 @@ class Util:
             if len(hv) == 1:
                 hv = "0" + hv
             lst.append(hv)
-        return reduce(lambda x, y: x + y, lst)
+        return reduce(operator.add, lst)
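Several hunks above replace throwaway lambdas with named callables from the standard library. A compact sketch of the three substitutions (sample data is made up):

    import operator
    from functools import reduce
    from itertools import starmap

    reduce(operator.add, ["de", "ad", "be", "ef"])             # "deadbeef", same as lambda x, y: x + y
    sorted([("b", 2), ("a", 1)], key=operator.itemgetter(0))   # sort by first element, same as lambda d: d[0]
    list(starmap("{}:{}".format, {"k": "v"}.items()))          # ["k:v"], same as the unpacking comprehension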
message["content"] else "", + content=message["content"] or "", tool_calls=tool_calls, ), usage=self._calc_response_usage( @@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage( - content=message.content if message.content else "", + content=message.content or "", tool_calls=tool_calls, ), usage=self._calc_response_usage( diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 8feedbfe55..ec3556f7da 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): secret_key=credentials["secret_key"], ) - user = user if user else "ErnieBotDefault" + user = user or "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ @@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): index=0, message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) else: @@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage(content=message.content, tool_calls=[]), - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index db323ae4c1..4d6f6dccd0 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel): api_key = credentials["api_key"] secret_key = credentials["secret_key"] embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) - user = user if user else "ErnieBotDefault" + user = user or "ErnieBotDefault" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 7ad236880b..f8f6c6b12d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -589,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # convert tool call to assistant message tool call tool_calls = assistant_message.tool_calls - assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or []) function_call = assistant_message.function_call if function_call: assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] @@ -652,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if 

             if delta.finish_reason is not None:
@@ -749,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]

             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage

diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py
index 60db151302..d29e8ed8f9 100644
--- a/api/core/model_runtime/model_providers/xinference/tts/tts.py
+++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py
@@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
                     for i in range(len(sentences))
                 ]

-                for index, future in enumerate(futures):
+                for future in futures:
                     response = future.result()
                     for i in range(0, len(response), 1024):
                         yield response[i : i + 1024]

diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 29b873fd06..f76e51fee9 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -414,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_tool_calls
             )

-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""

             if delta.finish_reason is not None and chunk.usage is not None:
                 completion_tokens = chunk.usage.completion_tokens

diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
index 48eeb37c41..5f7f6d04f2 100644
--- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
@@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
     return {key: val for key, val in merged.items() if val is not None}


+from itertools import starmap
+
 from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

 ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
@@ -159,7 +161,7 @@ class HttpClient:
             return [(key, str_data)]

     def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
-        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
+        items = flatten(list(starmap(self._object_to_formdata, data.items())))

         serialized: dict[str, object] = {}
         for key, value in items:

diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index a0f3ac7f86..6aefbec9aa 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)

     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+        trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
         user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
@@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
             self.add_trace(langfuse_trace_data=trace_data)
             workflow_span_data = LangfuseSpan(
-                id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+                id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
@@ -93,7 +93,7 @@
                 end_time=trace_info.end_time,
                 metadata=trace_info.metadata,
                 level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
-                status_message=trace_info.error if trace_info.error else "",
+                status_message=trace_info.error or "",
             )
             self.add_span(langfuse_span_data=workflow_span_data)
         else:
@@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance):
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
-                    parent_observation_id=(
-                        trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
-                    ),
+                    status_message=trace_info.error or "",
+                    parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 )
             else:
                 span_data = LangfuseSpan(
@@ -188,7 +186,7 @@
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
+                    status_message=trace_info.error or "",
                 )
             self.add_span(langfuse_span_data=span_data)

@@ -212,7 +210,7 @@
                     output=outputs,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
+                    status_message=trace_info.error or "",
                     usage=generation_usage,
                 )

@@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=message_data.answer,
             metadata=metadata,
             level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-            status_message=message_data.error if message_data.error else "",
+            status_message=message_data.error or "",
             usage=generation_usage,
         )

@@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
             level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-            status_message=message_data.error if message_data.error else "",
+            status_message=message_data.error or "",
             usage=generation_usage,
         )

diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index eea7bb3535..37cbea13fd 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
         langsmith_run = LangSmithRunModel(
             file_list=trace_info.file_list,
             total_tokens=trace_info.total_tokens,
-            id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+            id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
             name=TraceTaskName.WORKFLOW_TRACE.value,
             inputs=trace_info.workflow_run_inputs,
             run_type=LangSmithRunType.tool,
@@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             },
             error=trace_info.error,
             tags=["workflow"],
-            parent_run_id=trace_info.message_id if trace_info.message_id else None,
+            parent_run_id=trace_info.message_id or None,
         )

         self.add_run(langsmith_run)
@@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance):
                 extra={
                     "metadata": metadata,
                 },
-                parent_run_id=trace_info.workflow_app_log_id
-                if trace_info.workflow_app_log_id
-                else trace_info.workflow_run_id,
+                parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
                 tags=["node_execution"],
             )

diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 68fcdf32da..6f17bade97 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -354,11 +354,11 @@ class TraceTask:
         workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
         workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
         workflow_run_version = workflow_run.version
-        error = workflow_run.error if workflow_run.error else ""
+        error = workflow_run.error or ""

         total_tokens = workflow_run.total_tokens

-        file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
+        file_list = workflow_run_inputs.get("sys.file") or []
         query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

         # get workflow_app_log_id
@@ -452,7 +452,7 @@ class TraceTask:
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
             total_tokens=message_tokens + message_data.answer_tokens,
-            error=message_data.error if message_data.error else "",
+            error=message_data.error or "",
             inputs=inputs,
             outputs=message_data.answer,
             file_list=file_list,
@@ -487,7 +487,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

         moderation_trace_info = ModerationTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             inputs=inputs,
             message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
@@ -527,7 +527,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

         suggested_question_trace_info = SuggestedQuestionTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
@@ -569,7 +569,7 @@ class TraceTask:

         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
-            inputs=message_data.query if message_data.query else message_data.inputs,
+            inputs=message_data.query or message_data.inputs,
             documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
@@ -695,8 +695,7 @@ class TraceQueueManager:
             self.start_timer()

     def add_trace_task(self, trace_task: TraceTask):
-        global trace_manager_timer
-        global trace_manager_queue
+        global trace_manager_timer, trace_manager_queue
         try:
             if self.trace_instance:
                 trace_task.app_id = self.app_id

diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 13e5c5253e..7479560520 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
         for v in prompt_template_config["special_variable_keys"]:
             # support #context#, #query# and #histories#
             if v == "#context#":
-                variables["#context#"] = context if context else ""
+                variables["#context#"] = context or ""
             elif v == "#query#":
-                variables["#query#"] = query if query else ""
+                variables["#query#"] = query or ""
             elif v == "#histories#":
-                variables["#histories#"] = histories if histories else ""
+                variables["#histories#"] = histories or ""

         prompt_template = prompt_template_config["prompt_template"]
         prompt = prompt_template.format(variables)

diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py
index 27e4f383ad..4b9ec460e6 100644
--- a/api/core/rag/datasource/keyword/keyword_base.py
+++ b/api/core/rag/datasource/keyword/keyword_base.py
@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
index a9c0eefb78..ef48d22963 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,

diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
index cb38cf94a9..41f749279d 100644
--- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py
+++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)

         ids: list[str] = results["ids"][0]
         documents: list[str] = results["documents"][0]

diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index f13723b51f..c52ca5f488 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
                 id=uuids[i],
                 document={
                     Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
-                    Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+                    Field.VECTOR.value: embeddings[i] or None,
+                    Field.METADATA_KEY.value: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)
@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc.metadata["score"] = score
                 docs.append(doc)

diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index d6d7136282..7d78f54256 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
         for result in results[0]:
             metadata = result["entity"].get(Field.METADATA_KEY.value)
             metadata["score"] = result["distance"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result["distance"] > score_threshold:
                 doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
index 90464ac42a..fbd7864edd 100644
--- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py
+++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):

     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") or 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
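One semantic caveat about the score_threshold rewrites above: dict.get's default only applies when the key is absent, so the two spellings differ when a caller passes an explicit None. An illustrative comparison:

    kwargs = {"score_threshold": None}  # hypothetical call site
    assert (kwargs.get("score_threshold") or 0.0) == 0.0
    assert kwargs.get("score_threshold", 0.0) is None  # default unused: the key exists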
kwargs.get("score_threshold") else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if hit["_score"] > score_threshold: doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index d223b0decf..1636582362 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -200,7 +200,7 @@ class OracleVector(BaseVector): [numpy.array(query_vector)], ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -212,7 +212,7 @@ class OracleVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if len(query) > 0: # Check which language the query is in zh_pattern = re.compile("[\u4e00-\u9fa5]+") diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 24b391d63a..765436006b 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -198,7 +198,7 @@ class PGVectoRS(BaseVector): metadata = record.meta score = 1 - dis metadata["score"] = score - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if score > score_threshold: doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index d2d9e5238b..c35464b464 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -144,7 +144,7 @@ class PGVector(BaseVector): (json.dumps(query_vector),), ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) for record in cur: metadata, text, distance = record score = 1 - distance diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 83d561819c..2bf417e993 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -339,7 +339,7 @@ class QdrantVector(BaseVector): for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0 + score_threshold = kwargs.get("score_threshold", 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 54290eaa5d..76f214ffa4 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -230,7 +230,7 @@ class RelytVector(BaseVector): # Organize results. 
diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
index 54290eaa5d..76f214ffa4 100644
--- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py
+++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
         # Organize results.
         docs = []
         for document, score in results:
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if 1 - score > score_threshold:
                 docs.append(document)
         return docs
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index dbedc1d4e9..b550743f39 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             timeout=self._client_config.timeout,
         )
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         return self._get_search_res(res, score_threshold)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
index 7eaf189292..a2eff6ff04 100644
--- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold
 
diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py
index fb80cdec87..1a0dc7f48b 100644
--- a/api/core/rag/datasource/vdb/vector_base.py
+++ b/api/core/rag/datasource/vdb/vector_base.py
@@ -49,7 +49,7 @@ class BaseVector(ABC):
         raise NotImplementedError
 
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index bb24143a41..ca90233b7f 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -153,7 +153,7 @@ class Vector:
         return CacheEmbedding(embedding_model)
 
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index ca1123c6a0..64e92c5b82 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
 
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             # check score threshold
             if score > score_threshold:
                 doc.metadata["score"] = score
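`texts[:]` -> `texts.copy()` (FURB145) is purely cosmetic: both take a shallow copy so the list can be mutated while iterating. Why the copy matters at all, as a self-contained sketch of the duplicate-filtering pattern:

texts = ["a", "b", "b", "c"]
seen = set()
for text in texts.copy():   # iterate over a snapshot...
    if text in seen:
        texts.remove(text)  # ...so removing from the original is safe
    seen.add(text)
print(texts)  # ['a', 'b', 'c']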
diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py
index f4c7b4b5f7..e46ab8b7fd 100644
--- a/api/core/rag/extractor/blob/blob.py
+++ b/api/core/rag/extractor/blob/blob.py
@@ -12,7 +12,7 @@ import mimetypes
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Mapping
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, model_validator
@@ -56,8 +56,7 @@ class Blob(BaseModel):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(str(self.path)).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):
@@ -72,8 +71,7 @@ class Blob(BaseModel):
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(str(self.path)).read_bytes()
         else:
             raise ValueError(f"Unable to get bytes for blob {self}")
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index 244ef9614a..3181656f59 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -68,8 +68,7 @@ class ExtractProcessor:
                 suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, "wb") as file:
-                file.write(response.content)
+            Path(file_path).write_bytes(response.content)
 
             extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
             if return_text:
                 delimiter = "\n"
@@ -111,7 +110,7 @@ class ExtractProcessor:
                     )
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -143,7 +142,7 @@ class ExtractProcessor:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py
index 9a21d4272a..69ca9d5d63 100644
--- a/api/core/rag/extractor/helpers.py
+++ b/api/core/rag/extractor/helpers.py
@@ -1,6 +1,7 @@
 """Document loader helpers."""
 
 import concurrent.futures
+from pathlib import Path
 from typing import NamedTuple, Optional, cast
 
 
@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
     import chardet
 
     def read_and_detect(file_path: str) -> list[dict]:
-        with open(file_path, "rb") as f:
-            rawdata = f.read()
+        rawdata = Path(file_path).read_bytes()
         return cast(list[dict], chardet.detect_all(rawdata))
 
     with concurrent.futures.ThreadPoolExecutor() as executor:
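The `open()`/`read()` -> `pathlib.Path` rewrites (FURB101/FURB103) rely on `read_text`, `read_bytes`, and `write_bytes` opening and closing the file themselves, so the `with` block disappears. A round-trip sketch on a temporary file:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    p = Path(tmp) / "blob.bin"
    p.write_bytes(b"payload")            # was: with open(p, "wb") as f: f.write(...)
    assert p.read_bytes() == b"payload"  # was: with open(p, "rb") as f: f.read()
    p.write_text("text", encoding="utf-8")
    assert p.read_text(encoding="utf-8") == "text"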
diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py
index ca125ecf55..849852ac23 100644
--- a/api/core/rag/extractor/markdown_extractor.py
+++ b/api/core/rag/extractor/markdown_extractor.py
@@ -1,6 +1,7 @@
 """Abstract interface for document loader implementations."""
 
 import re
+from pathlib import Path
 from typing import Optional, cast
 
 from core.rag.extractor.extractor_base import BaseExtractor
@@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
         """Parse file into tuples."""
         content = ""
         try:
-            with open(filepath, encoding=self._encoding) as f:
-                content = f.read()
+            content = Path(filepath).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(filepath)
                 for encoding in detected_encodings:
                     try:
-                        with open(filepath, encoding=encoding.encoding) as f:
-                            content = f.read()
+                        content = Path(filepath).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue
diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py
index ed0ae41f51..b2b51d71d7 100644
--- a/api/core/rag/extractor/text_extractor.py
+++ b/api/core/rag/extractor/text_extractor.py
@@ -1,5 +1,6 @@
 """Abstract interface for document loader implementations."""
 
+from pathlib import Path
 from typing import Optional
 
 from core.rag.extractor.extractor_base import BaseExtractor
@@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
         """Load from file path."""
         text = ""
         try:
-            with open(self._file_path, encoding=self._encoding) as f:
-                text = f.read()
+            text = Path(self._file_path).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
                     try:
-                        with open(self._file_path, encoding=encoding.encoding) as f:
-                            text = f.read()
+                        text = Path(self._file_path).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 947644a90b..7352ef378b 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
                 if col_index >= total_cols:
                     break
                 cell_content = self._parse_cell(cell, image_map).strip()
-                cell_colspan = cell.grid_span if cell.grid_span else 1
+                cell_colspan = cell.grid_span or 1
                 for i in range(cell_colspan):
                     if col_index + i < total_cols:
                         row_cells[col_index + i] = cell_content if i == 0 else ""
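The markdown and text extractors keep their encoding-fallback control flow; only the read call changes. Reduced to a standalone sketch (assuming a detector that yields candidate encodings ordered by confidence, as `detect_file_encodings` does):

from pathlib import Path

def read_with_fallback(filepath: str, encoding: str, candidates: list[str]) -> str:
    try:
        return Path(filepath).read_text(encoding=encoding)
    except UnicodeDecodeError:
        for candidate in candidates:  # e.g. encodings from detect_file_encodings()
            try:
                return Path(filepath).read_text(encoding=candidate)
            except UnicodeDecodeError:
                continue
        raise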
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index e4ad78ed2b..cdaca8387d 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -256,7 +256,7 @@ class DatasetRetrieval:
         # get retrieval model config
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if dataset:
-            retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model_config = dataset.retrieval_model or default_retrieval_model
 
             # get top k
             top_k = retrieval_model_config["top_k"]
@@ -410,7 +410,7 @@ class DatasetRetrieval:
             return []
 
         # get retrieval model , if the model is not setting , using default
-        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+        retrieval_model = dataset.retrieval_model or default_retrieval_model
 
         if dataset.indexing_technique == "economy":
             # use keyword table query
@@ -433,9 +433,7 @@ class DatasetRetrieval:
                     reranking_model=retrieval_model.get("reranking_model", None)
                     if retrieval_model["reranking_enable"]
                     else None,
-                    reranking_mode=retrieval_model.get("reranking_mode")
-                    if retrieval_model.get("reranking_mode")
-                    else "reranking_model",
+                    reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                     weights=retrieval_model.get("weights", None),
                 )
 
@@ -486,7 +484,7 @@ class DatasetRetrieval:
         }
 
         for dataset in available_datasets:
-            retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model_config = dataset.retrieval_model or default_retrieval_model
 
             # get top k
             top_k = retrieval_model_config["top_k"]
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index 2e6018cffc..d99314e33a 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -106,7 +106,7 @@ class ApiToolProviderController(ToolProviderController):
                     "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
                     "llm": tool_bundle.summary or "",
                 },
-                "parameters": tool_bundle.parameters if tool_bundle.parameters else [],
+                "parameters": tool_bundle.parameters or [],
             }
         )
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
index 3c35b65e66..bffcd058b5 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
@@ -1,4 +1,5 @@
 import json
+import operator
 from typing import Any, Union
 
 import boto3
@@ -71,7 +72,7 @@ class SageMakerReRankTool(BuiltinTool):
                 candidate_docs[idx]["score"] = scores[idx]
 
             line = 8
-            sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
+            sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
 
             line = 9
             return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py
index 9e0918afa9..40e1af043b 100644
--- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py
+++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py
@@ -115,7 +115,7 @@ class GetWorksheetFieldsTool(BuiltinTool):
                 fields.append(field)
                 fields_list.append(
                     f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}"
-                    f"|{field['options'] if field['options'] else ''}|"
+                    f"|{field['options'] or ''}|"
                 )
 
         fields.append(
diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py
index 6b831f3145..01c1af9b3e 100644
--- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py
+++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py
@@ -130,7 +130,7 @@ class GetWorksheetPivotDataTool(BuiltinTool):
         #     ]
         rows = []
         for row in data["data"]:
-            row_data = row["rows"] if row["rows"] else {}
+            row_data = row["rows"] or {}
             row_data.update(row["columns"])
             row_data.update(row["values"])
             rows.append(row_data)
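`operator.itemgetter("score")` (FURB118) is a drop-in replacement for the `lambda x: x["score"]` sort key in the sagemaker tool above; it performs the same subscript lookup without a Python-level function body. For example:

import operator

docs = [{"id": 1, "score": 0.2}, {"id": 2, "score": 0.9}]
ranked = sorted(docs, key=operator.itemgetter("score"), reverse=True)
print([d["id"] for d in ranked])  # [2, 1]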
if result["total"] > 0: result_text += ( - f" The following are {result['total'] if result['total'] < limit else limit}" + f" The following are {min(limit, result['total'])}" f" pieces of data presented in a table format:\n\n{table_header}" ) for row in rows: diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index d7bfb53bd7..09725cf8a2 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -37,7 +37,7 @@ class SearchAPI: return { "engine": "youtube_transcripts", "video_id": video_id, - "lang": language if language else "en", + "lang": language or "en", **{key: value for key, value in kwargs.items() if value not in [None, ""]}, } diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index e76af6fe70..067600c601 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -160,7 +160,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): hit_callback.on_query(query, dataset.id) # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query @@ -183,9 +183,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") - if retrieval_model.get("reranking_mode") - else "reranking_model", + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index f61458278e..ad533946a1 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -55,7 +55,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): hit_callback.on_query(query, dataset.id) # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -76,9 +76,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") - if retrieval_model.get("reranking_mode") - else "reranking_model", + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), ) else: diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index fc2f63a241..e57cae9f16 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -8,6 +8,7 @@ import subprocess import 
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index fc2f63a241..e57cae9f16 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -8,6 +8,7 @@ import subprocess
 import tempfile
 import unicodedata
 from contextlib import contextmanager
+from pathlib import Path
 from urllib.parse import unquote
 
 import chardet
@@ -98,7 +99,7 @@ def get_url(url: str, user_agent: str = None) -> str:
             authors=a["byline"],
             publish_date=a["date"],
             top_image="",
-            text=a["plain_text"] if a["plain_text"] else "",
+            text=a["plain_text"] or "",
         )
 
     return res
@@ -117,8 +118,7 @@ def extract_using_readabilipy(html):
     subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
 
     # Read output of call to Readability.parse() from JSON file and return as Python dictionary
-    with open(article_json_path, encoding="utf-8") as json_file:
-        input_json = json.loads(json_file.read())
+    input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))
 
     # Deleting files after processing
     os.unlink(article_json_path)
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py
index bcb061376d..99b9f80499 100644
--- a/api/core/tools/utils/yaml_utils.py
+++ b/api/core/tools/utils/yaml_utils.py
@@ -21,7 +21,7 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
     with open(file_path, encoding="utf-8") as yaml_file:
         try:
             yaml_content = yaml.safe_load(yaml_file)
-            return yaml_content if yaml_content else default_value
+            return yaml_content or default_value
         except Exception as e:
             raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
     except Exception as e:
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index c156dd8c98..20efe91a59 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -268,7 +268,7 @@ class Graph(BaseModel):
                     f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
                 )
 
-            new_route = route[:]
+            new_route = route.copy()
             new_route.append(graph_edge.target_node_id)
             cls._check_connected_to_previous_node(
                 route=new_route,
@@ -679,8 +679,7 @@ class Graph(BaseModel):
         all_routes_node_ids = set()
         parallel_start_node_ids: dict[str, list[str]] = {}
         for branch_node_id, node_ids in routes_node_ids.items():
-            for node_id in node_ids:
-                all_routes_node_ids.add(node_id)
+            all_routes_node_ids.update(node_ids)
 
             if branch_node_id in reverse_edge_mapping:
                 for graph_edge in reverse_edge_mapping[branch_node_id]:
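`set.update(iterable)` (FURB142) replaces the element-by-element `add` loop in graph.py and accepts any iterable, including the generator expressions used by the event-handler hunks further down. A sketch:

all_ids = {"start"}
node_ids = ["a", "b", "c"]
all_ids.update(node_ids)  # same effect as: for n in node_ids: all_ids.add(n)
print(sorted(all_ids))    # ['a', 'b', 'c', 'start']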
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 73164fff9a..9da7ad99f3 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -74,7 +74,7 @@ class CodeNode(BaseNode):
         :return:
         """
         if not isinstance(value, str):
-            if isinstance(value, type(None)):
+            if value is None:
                 return None
             else:
                 raise ValueError(f"Output variable `{variable}` must be a string")
@@ -95,7 +95,7 @@ class CodeNode(BaseNode):
         :return:
         """
         if not isinstance(value, int | float):
-            if isinstance(value, type(None)):
+            if value is None:
                 return None
             else:
                 raise ValueError(f"Output variable `{variable}` must be a number")
@@ -182,7 +182,7 @@ class CodeNode(BaseNode):
                             f"Output {prefix}.{output_name} is not a valid array."
                             f" make sure all elements are of the same type."
                         )
-                elif isinstance(output_value, type(None)):
+                elif output_value is None:
                     pass
                 else:
                     raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
@@ -284,7 +284,7 @@ class CodeNode(BaseNode):
 
                     for i, value in enumerate(result[output_name]):
                         if not isinstance(value, dict):
-                            if isinstance(value, type(None)):
+                            if value is None:
                                 pass
                             else:
                                 raise ValueError(
diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py
index 5b4737c6e5..37384202d8 100644
--- a/api/core/workflow/nodes/if_else/if_else_node.py
+++ b/api/core/workflow/nodes/if_else/if_else_node.py
@@ -79,7 +79,7 @@ class IfElseNode(BaseNode):
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=node_inputs,
             process_data=process_datas,
-            edge_source_handle=selected_case_id if selected_case_id else "false",  # Use case ID or 'default'
+            edge_source_handle=selected_case_id or "false",  # Use case ID or 'default'
             outputs=outputs,
         )
 
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index 049c211488..3d336b0b0b 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -580,7 +580,7 @@ class LLMNode(BaseNode):
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=node_data.prompt_template,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             files=files,
             context=context,
             memory_config=node_data.memory,
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index d860f848ec..2ae58bc5f7 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -250,7 +250,7 @@ class QuestionClassifierNode(LLMNode):
         for class_ in classes:
             category = {"category_id": class_.id, "category_name": class_.name}
             categories.append(category)
-        instruction = node_data.instruction if node_data.instruction else ""
+        instruction = node_data.instruction or ""
         input_text = query
         memory_str = ""
         if memory:
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
index 59375b1a0b..de7c0f4dfe 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
@@ -18,8 +18,7 @@ def handle(sender, **kwargs):
         added_dataset_ids = dataset_ids
     else:
         old_dataset_ids = set()
-        for app_dataset_join in app_dataset_joins:
-            old_dataset_ids.add(app_dataset_join.dataset_id)
+        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
 
         added_dataset_ids = dataset_ids - old_dataset_ids
         removed_dataset_ids = old_dataset_ids - dataset_ids
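`value is None` (FURB169) is the idiomatic spelling of `isinstance(value, type(None))`: `type(None)` is just `NoneType`, and since `None` is a singleton, an identity check expresses the same test directly:

for value in (None, 0, "", False):
    assert (value is None) == isinstance(value, type(None))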
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
index 333b85ecb2..c5e98e263f 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -22,8 +22,7 @@ def handle(sender, **kwargs):
         added_dataset_ids = dataset_ids
     else:
         old_dataset_ids = set()
-        for app_dataset_join in app_dataset_joins:
-            old_dataset_ids.add(app_dataset_join.dataset_id)
+        old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)
 
         added_dataset_ids = dataset_ids - old_dataset_ids
         removed_dataset_ids = old_dataset_ids - dataset_ids
diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py
index 46ee4bf80f..f833ae85dc 100644
--- a/api/extensions/storage/local_storage.py
+++ b/api/extensions/storage/local_storage.py
@@ -1,6 +1,7 @@
 import os
 import shutil
 from collections.abc import Generator
+from pathlib import Path
 
 from flask import Flask
 
@@ -26,8 +27,7 @@ class LocalStorage(BaseStorage):
             folder = os.path.dirname(filename)
             os.makedirs(folder, exist_ok=True)
 
-        with open(os.path.join(os.getcwd(), filename), "wb") as f:
-            f.write(data)
+        Path(os.path.join(os.getcwd(), filename)).write_bytes(data)
 
     def load_once(self, filename: str) -> bytes:
         if not self.folder or self.folder.endswith("/"):
@@ -38,9 +38,7 @@ class LocalStorage(BaseStorage):
         if not os.path.exists(filename):
             raise FileNotFoundError("File not found")
 
-        with open(filename, "rb") as f:
-            data = f.read()
-
+        data = Path(filename).read_bytes()
         return data
 
     def load_stream(self, filename: str) -> Generator:
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 55f6ed3180..0da35910cd 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -144,7 +144,7 @@ class Dataset(db.Model):
             "top_k": 2,
             "score_threshold_enabled": False,
         }
-        return self.retrieval_model if self.retrieval_model else default_retrieval_model
+        return self.retrieval_model or default_retrieval_model
 
     @property
     def tags(self):
@@ -160,7 +160,7 @@ class Dataset(db.Model):
             .all()
         )
 
-        return tags if tags else []
+        return tags or []
 
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
diff --git a/api/models/model.py b/api/models/model.py
index 8ab3026522..a8b2e00ee4 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -118,7 +118,7 @@ class App(db.Model):
 
     @property
     def api_base_url(self):
-        return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1"
+        return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
 
     @property
     def tenant(self):
@@ -207,7 +207,7 @@ class App(db.Model):
             .all()
         )
 
-        return tags if tags else []
+        return tags or []
 
 
 class AppModelConfig(db.Model):
@@ -908,7 +908,7 @@ class Message(db.Model):
                         "id": message_file.id,
                         "type": message_file.type,
                         "url": url,
-                        "belongs_to": message_file.belongs_to if message_file.belongs_to else "user",
+                        "belongs_to": message_file.belongs_to or "user",
                     }
                 )
 
@@ -1212,7 +1212,7 @@ class Site(db.Model):
 
     @property
     def app_base_url(self):
-        return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/")
+        return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
 
 
 class ApiToken(db.Model):
@@ -1488,7 +1488,7 @@ class TraceAppConfig(db.Model):
 
     @property
     def tracing_config_dict(self):
-        return self.tracing_config if self.tracing_config else {}
+        return self.tracing_config or {}
 
     @property
     def tracing_config_str(self):
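A side note on the `return tags if tags else []` -> `return tags or []` rewrites in dataset.py and model.py (and in tag_service.py further down): if the SQLAlchemy query's `.all()` already returns a list (an empty one when nothing matches), the `or []` branch never produces a different result and the expression could likely be reduced to `return tags`. Illustrated with a plain list standing in for the query result:

tags = []  # what .all() is assumed to return when no rows match
assert (tags or []) == tags == []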
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 616794cf3a..7d64d07678 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -15,6 +15,7 @@ select = [
     "C4", # flake8-comprehensions
     "E", # pycodestyle E rules
     "F", # pyflakes rules
+    "FURB", # refurb rules
     "I", # isort rules
     "N", # pep8-naming
     "RUF019", # unnecessary-key-check
@@ -37,6 +38,8 @@ ignore = [
    "F405", # undefined-local-with-import-star-usage
    "F821", # undefined-name
    "F841", # unused-variable
+   "FURB113", # repeated-append
+   "FURB152", # math-constant
    "UP007", # non-pep604-annotation
    "UP032", # f-string
    "B005", # strip-with-multi-characters
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 7fb42f9e81..e839ae54ba 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -544,7 +544,7 @@ class RegisterService:
         """Register account"""
         try:
             account = AccountService.create_account(
-                email=email, name=name, interface_language=language if language else languages[0], password=password
+                email=email, name=name, interface_language=language or languages[0], password=password
             )
             account.status = AccountStatus.ACTIVE.value if not status else status.value
             account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 73c446b83b..8dcefbadc0 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -81,13 +81,11 @@ class AppDslService:
             raise ValueError("Missing app in data argument")
 
         # get app basic info
-        name = args.get("name") if args.get("name") else app_data.get("name")
-        description = args.get("description") if args.get("description") else app_data.get("description", "")
-        icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
-        icon = args.get("icon") if args.get("icon") else app_data.get("icon")
-        icon_background = (
-            args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
-        )
+        name = args.get("name") or app_data.get("name")
+        description = args.get("description") or app_data.get("description", "")
+        icon_type = args.get("icon_type") or app_data.get("icon_type")
+        icon = args.get("icon") or app_data.get("icon")
+        icon_background = args.get("icon_background") or app_data.get("icon_background")
         use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)
 
         # import dsl and create app
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index cce0874cf4..a52b5b7e7d 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -155,7 +155,7 @@ class DatasetService:
         dataset.tenant_id = tenant_id
         dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
         dataset.embedding_model = embedding_model.model if embedding_model else None
-        dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME
+        dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
         db.session.add(dataset)
         db.session.commit()
         return dataset
@@ -681,11 +681,7 @@ class DocumentService:
                     "score_threshold_enabled": False,
                 }
 
-                dataset.retrieval_model = (
-                    document_data.get("retrieval_model")
-                    if document_data.get("retrieval_model")
-                    else default_retrieval_model
-                )
+                dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model
 
         documents = []
         batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 2f911f5036..a9f963dbac 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -33,7 +33,7 @@ class HitTestingService:
 
         # get retrieval model , if the model is not setting , using default
         if not retrieval_model:
-            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model = dataset.retrieval_model or default_retrieval_model
 
         all_documents = RetrievalService.retrieve(
             retrieval_method=retrieval_model.get("search_method", "semantic_search"),
@@ -46,9 +46,7 @@ class HitTestingService:
             reranking_model=retrieval_model.get("reranking_model", None)
             if retrieval_model["reranking_enable"]
             else None,
-            reranking_mode=retrieval_model.get("reranking_mode")
-            if retrieval_model.get("reranking_mode")
-            else "reranking_model",
+            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
             weights=retrieval_model.get("weights", None),
         )
 
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index c0f3c40762..384a072b37 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -1,6 +1,7 @@
 import logging
 import mimetypes
 import os
+from pathlib import Path
 from typing import Optional, cast
 
 import requests
@@ -453,9 +454,8 @@ class ModelProviderService:
         mimetype = mimetype or "application/octet-stream"
 
         # read binary from file
-        with open(file_path, "rb") as f:
-            byte_data = f.read()
-        return byte_data, mimetype
+        byte_data = Path(file_path).read_bytes()
+        return byte_data, mimetype
 
     def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
         """
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 10abf0a764..daec8393d0 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,6 +1,7 @@
 import json
 import logging
 from os import path
+from pathlib import Path
 from typing import Optional
 
 import requests
@@ -218,10 +219,9 @@ class RecommendedAppService:
             return cls.builtin_data
 
         root_path = current_app.root_path
-        with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
-            json_data = f.read()
-            data = json.loads(json_data)
-            cls.builtin_data = data
+        cls.builtin_data = json.loads(
+            Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
+        )
 
         return cls.builtin_data
 
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 0c17485a9f..5e2851cd8f 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -57,7 +57,7 @@ class TagService:
             .all()
         )
 
-        return tags if tags else []
+        return tags or []
 
     @staticmethod
     def save_tags(args: dict) -> Tag:
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index dc8cebb587..e2e49d017e 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -1,5 +1,6 @@
 import json
 import logging
+from pathlib import Path
 
 from configs import dify_config
 from core.helper.position_helper import is_filtered
@@ -183,8 +184,7 @@ class BuiltinToolManageService:
         get tool provider icon and it's mimetype
         """
         icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
-        with open(icon_path, "rb") as f:
-            icon_bytes = f.read()
+        icon_bytes = Path(icon_path).read_bytes()
 
         return icon_bytes, mime_type
 
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 6dff35d63f..fea605cf30 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -50,8 +50,8 @@ class WebsiteService:
             excludes = options.get("excludes").split(",") if options.get("excludes") else []
             params = {
                 "crawlerOptions": {
-                    "includes": includes if includes else [],
-                    "excludes": excludes if excludes else [],
+                    "includes": includes or [],
+                    "excludes": excludes or [],
                     "generateImgAltText": True,
                     "limit": options.get("limit", 1),
                     "returnOnlyUrls": False,
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 4b845be2f4..db1a036e68 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -63,11 +63,11 @@ class WorkflowConverter:
         # create new app
         new_app = App()
         new_app.tenant_id = app_model.tenant_id
-        new_app.name = name if name else app_model.name + "(workflow)"
+        new_app.name = name or app_model.name + "(workflow)"
         new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
-        new_app.icon_type = icon_type if icon_type else app_model.icon_type
-        new_app.icon = icon if icon else app_model.icon
-        new_app.icon_background = icon_background if icon_background else app_model.icon_background
+        new_app.icon_type = icon_type or app_model.icon_type
+        new_app.icon = icon or app_model.icon
+        new_app.icon_background = icon_background or app_model.icon_background
         new_app.enable_site = app_model.enable_site
         new_app.enable_api = app_model.enable_api
         new_app.api_rpm = app_model.api_rpm
diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
index 83317e59de..b9a721c803 100644
--- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
+++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py
@@ -51,7 +51,7 @@ class MockTEIClass:
         #     }
         # }
         embeddings = []
-        for idx, text in enumerate(texts):
+        for idx in range(len(texts)):
             embedding = [0.1] * 768
             embeddings.append(
                 {
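The final hunk drops `enumerate` because the loop variable `text` was never read; `for idx in range(len(texts))` keeps only the index the mock actually uses. Equivalently, with hypothetical data:

texts = ["foo", "bar"]
embeddings = []
for idx in range(len(texts)):  # was: for idx, text in enumerate(texts)
    embeddings.append({"index": idx, "embedding": [0.1] * 4})
assert len(embeddings) == len(texts)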