diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 49f58af12c..a43be5fdf2 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline: if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") - elif isinstance(e, InvokeError) or isinstance(e, ValueError): + elif isinstance(e, InvokeError | ValueError): err = e else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 39f867118b..d5fda73009 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -45,7 +45,7 @@ class BaichuanModel: parameters: dict[str, Any], tools: Optional[list[PromptMessageTool]] = None, ) -> dict[str, Any]: - if model in self._model_mapping.keys(): + if model in self._model_mapping: # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters. # we need to rename it to res_format to get its value if parameters.get("res_format") == "json_object": @@ -94,7 +94,7 @@ class BaichuanModel: timeout: int, tools: Optional[list[PromptMessageTool]] = None, ) -> Union[Iterator, dict]: - if model in self._model_mapping.keys(): + if model in self._model_mapping: api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 274ff02095..307c15e1fd 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index ad5197a154..51b634c6cf 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 72c319d395..db0b2deaa5 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): break elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 09a7f53f28..110028a288 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index b2c837dee1..bc7531ee20 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ text = "" for item in message: - if isinstance(item, UserPromptMessage): - text += item.content - elif isinstance(item, SystemPromptMessage): - text += item.content - elif isinstance(item, AssistantPromptMessage): + if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage): text += item.content else: raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 484ac088db..498962bd0f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if copy_prompt_message.role == PromptMessageRole.USER: - new_prompt_messages.append(copy_prompt_message) - elif copy_prompt_message.role == PromptMessageRole.TOOL: + if ( + copy_prompt_message.role == PromptMessageRole.USER + or copy_prompt_message.role == PromptMessageRole.TOOL + ): new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 17e48b8fbe..dc6a7ec564 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -56,14 +56,7 @@ class KeywordsModeration(Moderation): ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: - for value in inputs.values(): - if self._check_keywords_in_value(keywords_list, value): - return True + return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) - return False - - def _check_keywords_in_value(self, keywords_list, value): - for keyword in keywords_list: - if keyword.lower() in value.lower(): - return True - return False + def _check_keywords_in_value(self, keywords_list, value) -> bool: + return any(keyword.lower() in value.lower() for keyword in keywords_list) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index d6156e479a..68fcdf32da 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -223,7 +223,7 @@ class OpsTraceManager: :return: """ # auth check - if tracing_provider not in provider_config_map.keys() and tracing_provider is not None: + if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: App = db.session.query(App).filter(App.id == app_id).first() diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 0c9d3b343d..54290eaa5d 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -127,27 +127,26 @@ class RelytVector(BaseVector): ) chunks_table_data = [] - with self.client.connect() as conn: - with conn.begin(): - for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding, - "document": document, - "metadata": metadata, - } - ) + with self.client.connect() as conn, conn.begin(): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } + ) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(chunks_table).values(chunks_table_data)) return ids @@ -186,11 +185,10 @@ class RelytVector(BaseVector): ) try: - with self.client.connect() as conn: - with conn.begin(): - delete_condition = chunks_table.c.id.in_(ids) - conn.execute(chunks_table.delete().where(delete_condition)) - return True + with self.client.connect() as conn, conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + conn.execute(chunks_table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index ada0c5cf46..dbedc1d4e9 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -63,10 +63,7 @@ class TencentVector(BaseVector): def _has_collection(self) -> bool: collections = self._db.list_collections() - for collection in collections: - if collection.collection_name == self._collection_name: - return True - return False + return any(collection.collection_name == self._collection_name for collection in collections) def _create_collection(self, dimension: int) -> None: lock_name = "vector_indexing_lock_{}".format(self._collection_name) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index e1ac9d596c..7eaf189292 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -124,20 +124,19 @@ class TiDBVector(BaseVector): texts = [d.page_content for d in documents] chunks_table_data = [] - with self._engine.connect() as conn: - with conn.begin(): - for id, text, meta, embedding in zip(ids, texts, metas, embeddings): - chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + with self._engine.connect() as conn, conn.begin(): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) return ids def text_exists(self, id: str) -> bool: @@ -160,11 +159,10 @@ class TiDBVector(BaseVector): raise ValueError("No ids provided to delete.") table = self._table(self._dimension) try: - with self._engine.connect() as conn: - with conn.begin(): - delete_condition = table.c.id.in_(ids) - conn.execute(table.delete().where(delete_condition)) - return True + with self._engine.connect() as conn, conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 2db00d161b..c6f15e55b6 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor): raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path - self.temp_file = tempfile.NamedTemporaryFile() + # TODO: use a better way to handle the file + self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 self.temp_file.write(r.content) self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4375079ee5..16d6b879a4 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -120,8 +120,8 @@ class WeightRerankRunner: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4948ec6ba8..e4ad78ed2b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -581,8 +581,8 @@ class DatasetRetrieval: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 592fa230cf..71f8356ab8 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool): elif value.startswith('[{"organizeId"'): value = json.loads(value) value = "、".join([item["organizeName"] for item in value]) - elif value.startswith('[{"file_id"'): - value = "" - elif value == "[]": + elif value.startswith('[{"file_id"') or value == "[]": value = "" elif hasattr(value, "accountId"): value = value["fullname"] diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index fe105f70a7..9ca14b327c 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool): models_data=[], headers=headers, params=params, - recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True, + recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), ) result_str = "" diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index cac59f76d8..d8ca20bde6 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -39,7 +39,7 @@ class QRCodeGeneratorTool(BuiltinTool): # get error_correction error_correction = tool_parameters.get("error_correction", "") - if error_correction not in self.error_correction_levels.keys(): + if error_correction not in self.error_correction_levels: return self.create_text_message("Invalid parameter error_correction") try: diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index d632304a46..6d88d74635 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -44,36 +44,36 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + if "answer_box" in res and "answer" in res["answer_box"]: toret += res["answer_box"]["answer"] + "\n" - if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + if "answer_box" in res and "snippet" in res["answer_box"]: toret += res["answer_box"]["snippet"] + "\n" - if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): + if "knowledge_graph" in res and "description" in res["knowledge_graph"]: toret += res["knowledge_graph"]["description"] + "\n" - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys(): - if "title" in res["answer_box"]["organic_result"].keys(): + if "answer_box" in res and "organic_result" in res["answer_box"]: + if "title" in res["answer_box"]["organic_result"]: toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n" - elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys(): + elif "organic_results" in res and "link" in res["organic_results"][0]: toret = "" for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys(): + elif "related_questions" in res and "link" in res["related_questions"][0]: toret = "" for item in res["related_questions"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys(): + elif "related_searches" in res and "link" in res["related_searches"][0]: toret = "" for item in res["related_searches"]: toret += f"[{item['title']}]({item['link']})\n" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 1544061c08..d29cb0ae3f 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -44,12 +44,12 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): + if "jobs" in res and "title" in res["jobs"][0]: for item in res["jobs"]: toret += ( "title: " @@ -65,7 +65,7 @@ class SearchAPI: toret = "No good search result found" elif type == "link": - if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys(): + if "jobs" in res and "apply_link" in res["jobs"][0]: for item in res["jobs"]: toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n" else: diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 95a7aad736..8458c8c958 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -44,25 +44,25 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" - if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + if "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys(): + if "organic_results" in res and "title" in res["organic_results"][0]: for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + elif "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += f"[{item['title']}]({item['link']})\n" else: diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 88def504fc..d7bfb53bd7 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -44,11 +44,11 @@ class SearchAPI: @staticmethod def _process_response(res: dict) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" - if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys(): + if "transcripts" in res and "text" in res["transcripts"][0]: for item in res["transcripts"]: toret += item["text"] + " " if toret == "": diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 12b6cc3352..9f415ceb55 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -35,7 +35,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): if model in ["sd3", "sd3-turbo"]: payload["model"] = tool_parameters.get("model") - if not model == "sd3-turbo": + if model != "sd3-turbo": payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 46137886bd..344f916494 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -206,10 +206,9 @@ class StableDiffusionTool(BuiltinTool): # Convert image to RGB and save as PNG try: - with Image.open(io.BytesIO(image_binary)) as image: - with io.BytesIO() as buffer: - image.convert("RGB").save(buffer, format="PNG") - image_binary = buffer.getvalue() + with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() except Exception as e: return self.create_text_message(f"Failed to process the image: {str(e)}") diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 67efcf0954..cb88e9519a 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -27,7 +27,7 @@ class WikipediaAPIWrapper: self.doc_content_chars_max = doc_content_chars_max def run(self, query: str, lang: str = "") -> str: - if lang in wikipedia.languages().keys(): + if lang in wikipedia.languages(): self.lang = lang wikipedia.set_lang(self.lang) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index c4983ebc65..1109ed7df2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -19,9 +19,7 @@ class ToolFileMessageTransformer: result = [] for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: + if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: result.append(message) elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index f1f677b8c1..c156dd8c98 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -224,9 +224,7 @@ class Graph(BaseModel): """ leaf_node_ids = [] for node_id in self.node_ids: - if node_id not in self.edge_mapping: - leaf_node_ids.append(node_id) - elif ( + if node_id not in self.edge_mapping or ( len(self.edge_mapping[node_id]) == 1 and self.edge_mapping[node_id][0].target_node_id == self.root_node_id ): diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 06050e1549..e31a1479a8 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -24,7 +24,7 @@ class AnswerStreamGeneratorRouter: # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} for answer_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value: + if node_config.get("data", {}).get("type") != NodeType.ANSWER.value: continue # get generate route for stream output diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 30ce8fe018..a38d982393 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -17,7 +17,7 @@ class EndStreamGeneratorRouter: # parse stream output node value selector of end nodes end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} for end_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get("data", {}).get("type") == NodeType.END.value: + if node_config.get("data", {}).get("type") != NodeType.END.value: continue # skip end node in parallel diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 28fbf789fd..9d222b10b9 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -20,7 +20,7 @@ class ToolEntity(BaseModel): if not isinstance(value, dict): raise ValueError("tool_configurations must be a dictionary") - for key in values.data.get("tool_configurations", {}).keys(): + for key in values.data.get("tool_configurations", {}): value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): raise ValueError(f"{key} must be a string") diff --git a/api/pyproject.toml b/api/pyproject.toml index 23e2b5c549..3d100ebc58 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -17,14 +17,12 @@ select = [ "F", # pyflakes rules "I", # isort rules "N", # pep8-naming - "UP", # pyupgrade rules "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa "S506", # unsafe-yaml-load - "SIM116", # if-else-block-instead-of-dict-lookup - "SIM401", # if-else-block-instead-of-dict-get - "SIM910", # dict-get-with-none-default + "SIM", # flake8-simplify rules + "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence ] @@ -50,6 +48,15 @@ ignore = [ "B905", # zip-without-explicit-strict "N806", # non-lowercase-variable-in-function "N815", # mixed-case-variable-in-class-scope + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # eumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "SIM300", # yoda-conditions ] [tool.ruff.lint.per-file-ignores] diff --git a/api/services/file_service.py b/api/services/file_service.py index 5780abb2be..bedec76334 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -56,9 +56,7 @@ class FileService: if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS ) - if extension.lower() not in allowed_extensions: - raise UnsupportedFileTypeError() - elif only_image and extension.lower() not in IMAGE_EXTENSIONS: + if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS: raise UnsupportedFileTypeError() # read file content diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1e7935d299..d8e2b1689a 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -54,7 +54,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider: + if tracing_provider not in provider_config_map and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} config_class, other_keys = ( @@ -113,7 +113,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys(): + if tracing_provider not in provider_config_map: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists