diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 3f5e1adca2..35ac42a14c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() current_key_count = ( diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 46c0b22993..df7bd352af 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -20,7 +20,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation @@ -36,8 +36,8 @@ class CompletionConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) @@ -143,8 +143,8 @@ class ChatConversationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("keyword", type=str, location="args") - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - 
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument( "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" ) diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 81826a20d0..4806b02b55 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode @@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() 
sql_query = """ @@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d 
%H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index db2f683589..942271a634 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode from models.workflow import WorkflowRunTriggeredFrom @@ -26,8 +26,8 @@ class 
WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") - parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = 
parser.parse_args() sql_query = """ diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8ba6b53e7e..f3198dfc1d 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -8,7 +8,7 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import email, str_len, timezone +from libs.helper import StrLen, email, timezone from libs.password import hash_password, valid_password from models.account import AccountStatus from services.account_service import RegisterService @@ -37,7 +37,7 @@ class ActivateApi(Resource): parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument( "interface_language", type=supported_language, required=True, nullable=False, location="json" diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 7d3ae677ee..ae759bb752 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -4,7 +4,7 @@ from flask import session from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import str_len +from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -28,7 +28,7 @@ class InitValidateAPI(Resource): raise 
AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument("password", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=StrLen(30), required=True, location="json") input_password = parser.parse_args()["password"] if input_password != os.environ.get("INIT_PASSWORD"): diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 827695e00f..46b4ef5d87 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask import request from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import email, get_remote_ip, str_len +from libs.helper import StrLen, email, get_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -40,7 +40,7 @@ class SetupApi(Resource): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") - parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("name", type=StrLen(30), required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 1277dcebc5..88e1256ed5 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager 
import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom @@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index d9fc599542..18b115dfe4 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -21,7 +21,7 @@ class AudioTrunk: self.status = status -def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice ) future_queue.put(futures_result) break @@ -97,7 +97,7 @@ class 
AppGeneratorTTSPublisher: self.MAX_SENTENCE += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( - _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice ) future_queue.put(futures_result) if text_tmp: @@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher: break future_queue.put(None) - def checkAndGetAudio(self) -> AudioTrunk | None: + def check_and_get_audio(self) -> AudioTrunk | None: try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 90f547b0f2..c4cdba6441 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,7 +19,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, QueueTextChunkEvent, ) -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.variable_pool import VariablePool @@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query=query, message_id=message_id, ) - except ModerationException as e: + except ModerationError as e: self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8f65a670c3..94206a1b1c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response, ) - def 
_listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 7ba6bbab94..abf8a332ab 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator 
import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom @@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 6b676b0353..45b1bf0093 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,7 +15,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought @@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index df972756d5..f3c3199354 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -171,5 +171,5 @@ class AppQueueManager: ) -class GenerateTaskStoppedException(Exception): +class GenerateTaskStoppedError(Exception): pass diff --git a/api/core/app/apps/chat/app_generator.py 
b/api/core/app/apps/chat/app_generator.py index 15c7140308..032556ec4c 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter @@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index bd90586825..425f1ab7ef 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message @@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except 
ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index d7301224e8..7fce296f2b 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter @@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): queue_manager=queue_manager, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index da49c8701f..908d74ff53 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from 
models.model import App, Message @@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner): query=query, message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index a91d48d246..f629c5c8b7 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -8,7 +8,7 @@ from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 7f259db6eb..363c3c82bb 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -53,4 +53,4 @@ class 
MessageBasedAppQueueManager(AppQueueManager): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index c685008577..57a77591a0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -12,7 +12,7 @@ from pydantic import ValidationError import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator): ) runner.run() - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator): return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index c9f501cd5e..76371f800b 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, 
PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 215d02bddd..93edf8e0e8 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa try: if not tts_publisher: break - audio_trunk = tts_publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is 
None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index b71924b2d3..b26b3c8291 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -15,6 +15,7 @@ class Segment(BaseModel): value: Any @field_validator("value_type") + @classmethod def validate_value_type(cls, value): """ This validator checks if the provided value is equal to the default value of the 'value_type' field. diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 49f58af12c..a43be5fdf2 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline: if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") - elif isinstance(e, InvokeError) or isinstance(e, ValueError): + elif isinstance(e, InvokeError | ValueError): err = e else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 61e920845c..659503301e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": # audio_str = 
audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) @@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id) + audio_response = self._listen_audio_msg(publisher, task_id) if audio_response: yield audio_response else: @@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: if publisher is None: break - audio = publisher.checkAndGetAudio() + audio = publisher.check_and_get_audio() if audio is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 50cde18c54..6d5393ce5c 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler: data_source_type=item.get("data_source_type"), segment_id=item.get("segment_id"), score=item.get("score") if "score" in item else None, - hit_count=item.get("hit_count") if "hit_count" else None, + hit_count=item.get("hit_count") if "hit_count" in item else None, word_count=item.get("word_count") if "word_count" in item else None, segment_position=item.get("segment_position") if "segment_position" in item else None, index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4a80a3ffe9..7ee6e63817 100644 --- 
a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) -class CodeExecutionException(Exception): +class CodeExecutionError(Exception): pass @@ -86,15 +86,15 @@ class CodeExecutor: ), ) if response.status_code == 503: - raise CodeExecutionException("Code execution service is unavailable") + raise CodeExecutionError("Code execution service is unavailable") elif response.status_code != 200: raise Exception( f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" ) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e except Exception as e: - raise CodeExecutionException( + raise CodeExecutionError( "Failed to execute code, which is likely a network issue," " please check if the sandbox service is running." f" ( Error: {str(e)} )" @@ -103,15 +103,15 @@ class CodeExecutor: try: response = response.json() except: - raise CodeExecutionException("Failed to parse response") + raise CodeExecutionError("Failed to parse response") if (code := response.get("code")) != 0: - raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response.get('message')}") response = CodeExecutionResponse(**response) if response.data.error: - raise CodeExecutionException(response.data.error) + raise CodeExecutionError(response.data.error) return response.data.stdout or "" @@ -126,13 +126,13 @@ class CodeExecutor: """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f"Unsupported language {language}") + raise CodeExecutionError(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) try: response = cls.execute_code(language, preload, runner) - except CodeExecutionException as e: + except CodeExecutionError as e: raise e return template_transformer.transform_response(response) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b6968e46cd..eeb1dbfda0 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -78,8 +78,8 @@ class IndexingRunner: dataset_document=dataset_document, documents=documents, ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ -134,8 +134,8 @@ class IndexingRunner: self._load( index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ 
-192,8 +192,8 @@ class IndexingRunner: self._load( index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) @@ -756,7 +756,7 @@ class IndexingRunner: indexing_cache_key = "document_{}_is_paused".format(document_id) result = redis_client.get(indexing_cache_key) if result: - raise DocumentIsPausedException() + raise DocumentIsPausedError() @staticmethod def _update_document_index_status( @@ -767,10 +767,10 @@ class IndexingRunner: """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: - raise DocumentIsPausedException() + raise DocumentIsPausedError() document = DatasetDocument.query.filter_by(id=document_id).first() if not document: - raise DocumentIsDeletedPausedException() + raise DocumentIsDeletedPausedError() update_params = {DatasetDocument.indexing_status: after_indexing_status} @@ -875,9 +875,9 @@ class IndexingRunner: pass -class DocumentIsPausedException(Exception): +class DocumentIsPausedError(Exception): pass -class DocumentIsDeletedPausedException(Exception): +class DocumentIsDeletedPausedError(Exception): pass diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 6a60f8de80..1e743f1757 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserException(Exception): +class OutputParserError(Exception): pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py 
b/api/core/llm_generator/output_parser/rule_config_generator.py index b6932698cb..0c7683b16d 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,6 +1,6 @@ from typing import Any -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser: raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 2ad2289869..3b9fb52e24 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -420,7 +420,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ), ) - index += 0 + index += 1 # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 6e181ac5f8..d5fda73009 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -7,7 +7,7 @@ from requests import post from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors 
import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -45,7 +45,7 @@ class BaichuanModel: parameters: dict[str, Any], tools: Optional[list[PromptMessageTool]] = None, ) -> dict[str, Any]: - if model in self._model_mapping.keys(): + if model in self._model_mapping: # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters. # we need to rename it to res_format to get its value if parameters.get("res_format") == "json_object": @@ -94,7 +94,7 @@ class BaichuanModel: timeout: int, tools: Optional[list[PromptMessageTool]] = None, ) -> Union[Iterator, dict]: - if model in self._model_mapping.keys(): + if model in self._model_mapping: api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") @@ -124,7 +124,7 @@ class BaichuanModel: if err == "invalid_api_key": raise InvalidAPIKeyError(msg) elif err == "insufficient_quota": - raise InsufficientAccountBalance(msg) + raise InsufficientAccountBalanceError(msg) elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) elif err == "invalid_request_error": diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 4e56e58d7e..309b5cf413 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -10,7 +10,7 @@ class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): +class InsufficientAccountBalanceError(Exception): pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index 3291fe2b2e..91a14bf100 100644 --- 
a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel): InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index b7276fabb5..779dfbb608 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): if err == "invalid_api_key": raise InvalidAPIKeyError(msg) elif err == "insufficient_quota": - raise InsufficientAccountBalance(msg) + raise InsufficientAccountBalanceError(msg) elif err == "invalid_authentication": raise 
InvalidAuthenticationError(msg) elif err and "rate" in err: @@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml index 53657c08a9..c2d5eb6471 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.00025' output: '0.00125' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml index d083d31e30..f90fa04266 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format pricing: input: '0.015' output: '0.075' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml index 5302231086..dad0d6b6b6 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml index 6995d2bf56..962def8011 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml index 1a3239c85e..70294e4ad3 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml index 0343e3bbec..0a8ea61b6d 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index a2a69b86bb..e07f2a419a 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -20,6 +20,7 @@ from botocore.exceptions import ( from PIL.Image import Image # local import +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -44,6 +45,14 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel logger = logging.getLogger(__name__) +ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. 
+ + +{{instructions}} + +""" class BedrockLargeLanguageModel(LargeLanguageModel): @@ -70,6 +79,40 @@ class BedrockLargeLanguageModel(LargeLanguageModel): logger.info(f"current model id: {model_id} did not support by Converse API") return None + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: list[Callback] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if model_parameters.get("response_format"): + stop = stop or [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + response_format = model_parameters.pop("response_format") + format_prompt = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = format_prompt + else: + prompt_messages.insert(0, format_prompt) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + def _invoke( self, model: str, diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 274ff02095..307c15e1fd 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, 
SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index ad5197a154..51b634c6cf 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index ffa9ea2f90..2cfb79b241 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -77,7 +77,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): inputs.append(text) # Prepare the payload for the request - payload = {"input": inputs, "model": model, "options": {"use_mmap": "true"}} + payload = {"input": inputs, "model": model, "options": {"use_mmap": True}} # Make the request to the Ollama API response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py 
b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 257dffa30d..1234e44f80 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import ( ) -class _CommonOAI_API_Compat: +class _CommonOaiApiCompat: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 75929af590..24317b488c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper logger = logging.getLogger(__name__) -class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): +class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ Model class for OpenAI large language model. 
""" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 2e8b4ddd72..405096578c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -6,10 +6,10 @@ import requests from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): +class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): """ Model class for OpenAI Compatible Speech to text model. 
""" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index ab358cf70a..e83cfdf873 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index d0522233e3..b62a2d2aaf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 72c319d395..db0b2deaa5 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): break elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index ecb22e21bd..110028a288 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -115,7 +115,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): token = credentials.token # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region - if "opus" or "claude-3-5-sonnet" in model: + if "opus" in model or "claude-3-5-sonnet" in model: location = "us-east5" else: location = "us-central1" @@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") diff --git 
a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py index 025b1ed6d2..266f1216f8 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService class MaaSClient(MaasService): @@ -106,7 +106,7 @@ class MaaSClient(MaasService): def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: try: resp = fn() - except MaasException as e: + except MaasError as e: raise wrap_error(e) return resp diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 8b9c346265..91dbe21a61 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -1,144 +1,144 @@ -from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError -class ClientSDKRequestError(MaasException): +class ClientSDKRequestError(MaasError): pass -class SignatureDoesNotMatch(MaasException): +class SignatureDoesNotMatchError(MaasError): pass -class RequestTimeout(MaasException): +class RequestTimeoutError(MaasError): pass -class ServiceConnectionTimeout(MaasException): +class ServiceConnectionTimeoutError(MaasError): pass -class MissingAuthenticationHeader(MaasException): 
+class MissingAuthenticationHeaderError(MaasError): pass -class AuthenticationHeaderIsInvalid(MaasException): +class AuthenticationHeaderIsInvalidError(MaasError): pass -class InternalServiceError(MaasException): +class InternalServiceError(MaasError): pass -class MissingParameter(MaasException): +class MissingParameterError(MaasError): pass -class InvalidParameter(MaasException): +class InvalidParameterError(MaasError): pass -class AuthenticationExpire(MaasException): +class AuthenticationExpireError(MaasError): pass -class EndpointIsInvalid(MaasException): +class EndpointIsInvalidError(MaasError): pass -class EndpointIsNotEnable(MaasException): +class EndpointIsNotEnableError(MaasError): pass -class ModelNotSupportStreamMode(MaasException): +class ModelNotSupportStreamModeError(MaasError): pass -class ReqTextExistRisk(MaasException): +class ReqTextExistRiskError(MaasError): pass -class RespTextExistRisk(MaasException): +class RespTextExistRiskError(MaasError): pass -class EndpointRateLimitExceeded(MaasException): +class EndpointRateLimitExceededError(MaasError): pass -class ServiceConnectionRefused(MaasException): +class ServiceConnectionRefusedError(MaasError): pass -class ServiceConnectionClosed(MaasException): +class ServiceConnectionClosedError(MaasError): pass -class UnauthorizedUserForEndpoint(MaasException): +class UnauthorizedUserForEndpointError(MaasError): pass -class InvalidEndpointWithNoURL(MaasException): +class InvalidEndpointWithNoURLError(MaasError): pass -class EndpointAccountRpmRateLimitExceeded(MaasException): +class EndpointAccountRpmRateLimitExceededError(MaasError): pass -class EndpointAccountTpmRateLimitExceeded(MaasException): +class EndpointAccountTpmRateLimitExceededError(MaasError): pass -class ServiceResourceWaitQueueFull(MaasException): +class ServiceResourceWaitQueueFullError(MaasError): pass -class EndpointIsPending(MaasException): +class EndpointIsPendingError(MaasError): pass -class ServiceNotOpen(MaasException): +class 
ServiceNotOpenError(MaasError): pass AuthErrors = { - "SignatureDoesNotMatch": SignatureDoesNotMatch, - "MissingAuthenticationHeader": MissingAuthenticationHeader, - "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, - "AuthenticationExpire": AuthenticationExpire, - "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, + "SignatureDoesNotMatch": SignatureDoesNotMatchError, + "MissingAuthenticationHeader": MissingAuthenticationHeaderError, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, + "AuthenticationExpire": AuthenticationExpireError, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, } BadRequestErrors = { - "MissingParameter": MissingParameter, - "InvalidParameter": InvalidParameter, - "EndpointIsInvalid": EndpointIsInvalid, - "EndpointIsNotEnable": EndpointIsNotEnable, - "ModelNotSupportStreamMode": ModelNotSupportStreamMode, - "ReqTextExistRisk": ReqTextExistRisk, - "RespTextExistRisk": RespTextExistRisk, - "InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, - "ServiceNotOpen": ServiceNotOpen, + "MissingParameter": MissingParameterError, + "InvalidParameter": InvalidParameterError, + "EndpointIsInvalid": EndpointIsInvalidError, + "EndpointIsNotEnable": EndpointIsNotEnableError, + "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, + "ReqTextExistRisk": ReqTextExistRiskError, + "RespTextExistRisk": RespTextExistRiskError, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, + "ServiceNotOpen": ServiceNotOpenError, } RateLimitErrors = { - "EndpointRateLimitExceeded": EndpointRateLimitExceeded, - "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, - "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, + "EndpointRateLimitExceeded": EndpointRateLimitExceededError, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError, } 
ServerUnavailableErrors = { "InternalServiceError": InternalServiceError, - "EndpointIsPending": EndpointIsPending, - "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, + "EndpointIsPending": EndpointIsPendingError, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, } ConnectionErrors = { "ClientSDKRequestError": ClientSDKRequestError, - "RequestTimeout": RequestTimeout, - "ServiceConnectionTimeout": ServiceConnectionTimeout, - "ServiceConnectionRefused": ServiceConnectionRefused, - "ServiceConnectionClosed": ServiceConnectionClosed, + "RequestTimeout": RequestTimeoutError, + "ServiceConnectionTimeout": ServiceConnectionTimeoutError, + "ServiceConnectionRefused": ServiceConnectionRefusedError, + "ServiceConnectionClosed": ServiceConnectionClosedError, } ErrorCodeMap = { @@ -150,7 +150,7 @@ ErrorCodeMap = { } -def wrap_error(e: MaasException) -> Exception: +def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 53f320736b..8b3eb157be 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole -from .maas import MaasException, MaasService +from .maas import MaasError, MaasService -__all__ = ["MaasService", "ChatRole", "MaasException"] +__all__ = ["MaasService", "ChatRole", "MaasError"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 8f8139426c..7435720252 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ 
b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -74,7 +74,7 @@ class Signer: def sign(request, credentials): if request.path == "": request.path = "/" - if request.method != "GET" and not ("Content-Type" in request.headers): + if request.method != "GET" and "Content-Type" not in request.headers: request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" format_date = Signer.get_current_format_date() diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py index 096339b3c7..33c41f3eb3 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -31,7 +31,7 @@ class Service: self.service_info.scheme = scheme def get(self, api, params, doseq=0): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] @@ -49,7 +49,7 @@ class Service: raise Exception(resp.text) def post(self, api, params, form): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) @@ -72,7 +72,7 @@ class Service: raise Exception(resp.text) def json(self, api, params, body): - if not (api in self.api_info): + if api not in self.api_info: raise Exception("no such api") api_info = self.api_info[api] r = self.prepare_request(api_info, params) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py index 01f15aec24..a3836685f1 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py +++ 
b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -63,7 +63,7 @@ class MaasService(Service): raise if res.error is not None and res.error.code_n != 0: - raise MaasException( + raise MaasError( res.error.code_n, res.error.code, res.error.message, @@ -72,7 +72,7 @@ class MaasService(Service): yield res return iter_fn() - except MaasException: + except MaasError: raise except Exception as e: raise new_client_sdk_request_error(str(e)) @@ -94,7 +94,7 @@ class MaasService(Service): resp["req_id"] = req_id return resp - except MaasException as e: + except MaasError as e: raise e except Exception as e: raise new_client_sdk_request_error(str(e), req_id) @@ -109,7 +109,7 @@ class MaasService(Service): if not self._apikey and not credentials_exist: raise new_client_sdk_request_error("no valid credential", req_id) - if not (api in self.api_info): + if api not in self.api_info: raise new_client_sdk_request_error("no such api", req_id) def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): @@ -147,14 +147,14 @@ class MaasService(Service): raise new_client_sdk_request_error(raw, req_id) if resp.error: - raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) else: raise new_client_sdk_request_error(resp, req_id) return res -class MaasException(Exception): +class MaasError(Exception): def __init__(self, code_n, code, message, req_id): self.code_n = code_n self.code = code @@ -172,7 +172,7 @@ class MaasException(Exception): def new_client_sdk_request_error(raw, req_id=""): - return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) class BinaryResponseContent: @@ -192,7 +192,7 @@ class BinaryResponseContent: if len(error_bytes) > 0: resp = 
json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) def iter_bytes(self) -> Iterator[bytes]: yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 98409ab872..c25851fc45 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): }, [UserPromptMessage(content="ping\nAnswer: ")], ) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) @staticmethod diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 3cdcd2740c..9cba2cb879 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, - MaasException, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) @@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: self._invoke(model=model, credentials=credentials, texts=["ping"]) - except 
MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) def _validate_credentials_v3(self, model: str, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py index f2e2248680..bd074e0477 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], InvokeBadRequestError: [BadRequestError, KeyError], @@ -42,7 +42,7 @@ class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): +class InsufficientAccountBalanceError(Exception): pass diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index b2c837dee1..bc7531ee20 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ text = "" for item in message: - if isinstance(item, UserPromptMessage): - text += item.content - elif isinstance(item, SystemPromptMessage): - text += item.content - elif isinstance(item, AssistantPromptMessage): + if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage): text += item.content else: raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 484ac088db..498962bd0f 100644 --- 
a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if copy_prompt_message.role == PromptMessageRole.USER: - new_prompt_messages.append(copy_prompt_message) - elif copy_prompt_message.role == PromptMessageRole.TOOL: + if ( + copy_prompt_message.role == PromptMessageRole.USER + or copy_prompt_message.role == PromptMessageRole.TOOL + ): new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 4b91f20184..60898d5547 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -76,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): @@ -111,5 +111,5 @@ class Moderation(Extensible, ABC): raise ValueError("outputs_config.preset_response must be less than 100 characters") -class 
ModerationException(Exception): +class ModerationError(Exception): pass diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 336c16eecf..46d3963bd0 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -2,7 +2,7 @@ import logging from typing import Optional from core.app.app_config.entities import AppConfig -from core.moderation.base import ModerationAction, ModerationException +from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask @@ -61,7 +61,7 @@ class InputModeration: return False, inputs, query if moderation_result.action == ModerationAction.DIRECT_OUTPUT: - raise ModerationException(moderation_result.preset_response) + raise ModerationError(moderation_result.preset_response) elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 17e48b8fbe..dc6a7ec564 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -56,14 +56,7 @@ class KeywordsModeration(Moderation): ) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: - for value in inputs.values(): - if self._check_keywords_in_value(keywords_list, value): - return True + return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) - return False - - def _check_keywords_in_value(self, keywords_list, value): - for keyword in keywords_list: - if keyword.lower() in value.lower(): - return True - return False + def _check_keywords_in_value(self, keywords_list, value) -> bool: + return any(keyword.lower() in value.lower() for keyword in keywords_list) diff --git 
a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 0ab2139a88..5c79867571 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig): host: str = "https://api.langfuse.com" @field_validator("host") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": v = "https://api.langfuse.com" @@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig): endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": v = "https://api.smith.langchain.com" diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a3ce27d5d4..f27a0af6e0 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel): metadata: dict[str, Any] @field_validator("inputs", "outputs") + @classmethod def ensure_type(cls, v): if v is None: return None diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index af7661f0af..447b799f1f 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): totalCost: Optional[float] = None @field_validator("input", "output") + 
@classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel): model_config = ConfigDict(protected_namespaces=()) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 8cbf162bf2..05c932fb99 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name values = info.data @@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): return v return v + @classmethod @field_validator("start_time", "end_time") def format_time(cls, v, info: ValidationInfo): if not isinstance(v, datetime): diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index d6156e479a..68fcdf32da 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -223,7 +223,7 @@ class OpsTraceManager: :return: """ # auth check - if tracing_provider not in provider_config_map.keys() and tracing_provider is not None: + if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: App = db.session.query(App).filter(App.id == app_id).first() diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py 
b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 76c808f76e..f13723b51f 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel): password: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config HOST is required") diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 1d08046641..d6d7136282 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -28,6 +28,7 @@ class MilvusConfig(BaseModel): database: str = "default" @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values.get("uri"): raise ValueError("config MILVUS_URI is required") diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index ecd7e0271c..7c0f620956 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel): secure: bool = False @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index eb2e3e0a8c..06c20ceb5f 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not 
values["host"]: raise ValueError("config ORACLE_HOST is required") diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index b778582e8a..24b391d63a 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index b01cd91e07..38dfd24b56 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index d8e4ff628c..54290eaa5d 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -34,6 +34,7 @@ class RelytConfig(BaseModel): database: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config RELYT_HOST is required") @@ -126,27 +127,26 @@ class RelytVector(BaseVector): ) chunks_table_data = [] - with self.client.connect() as conn: - with conn.begin(): - for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding, - "document": document, - "metadata": metadata, - } - ) + with self.client.connect() as conn, conn.begin(): + for document, 
metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } + ) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(chunks_table).values(chunks_table_data)) return ids @@ -185,11 +185,10 @@ class RelytVector(BaseVector): ) try: - with self.client.connect() as conn: - with conn.begin(): - delete_condition = chunks_table.c.id.in_(ids) - conn.execute(chunks_table.delete().where(delete_condition)) - return True + with self.client.connect() as conn, conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + conn.execute(chunks_table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index ada0c5cf46..dbedc1d4e9 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -63,10 +63,7 @@ class TencentVector(BaseVector): def _has_collection(self) -> bool: collections = self._db.list_collections() - for collection in collections: - if collection.collection_name == self._collection_name: - return True - return False + return 
any(collection.collection_name == self._collection_name for collection in collections) def _create_collection(self, dimension: int) -> None: lock_name = "vector_indexing_lock_{}".format(self._collection_name) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 0e4b3f67a1..7eaf189292 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel): program_name: str @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") @@ -123,20 +124,19 @@ class TiDBVector(BaseVector): texts = [d.page_content for d in documents] chunks_table_data = [] - with self._engine.connect() as conn: - with conn.begin(): - for id, text, meta, embedding in zip(ids, texts, metas, embeddings): - chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + with self._engine.connect() as conn, conn.begin(): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) 
return ids def text_exists(self, id: str) -> bool: @@ -159,11 +159,10 @@ class TiDBVector(BaseVector): raise ValueError("No ids provided to delete.") table = self._table(self._dimension) try: - with self._engine.connect() as conn: - with conn.begin(): - delete_condition = table.c.id.in_(ids) - conn.execute(table.delete().where(delete_condition)) - return True + with self._engine.connect() as conn, conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 750172b015..ca1123c6a0 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel): batch_size: int = 100 @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 2db00d161b..c6f15e55b6 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor): raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path - self.temp_file = tempfile.NamedTemporaryFile() + # TODO: use a better way to handle the file + self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 self.temp_file.write(r.content) self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4375079ee5..16d6b879a4 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ 
b/api/core/rag/rerank/weight_rerank.py @@ -120,8 +120,8 @@ class WeightRerankRunner: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 4948ec6ba8..e4ad78ed2b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -581,8 +581,8 @@ class DatasetRetrieval: intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 09f30a59d6..7462824be1 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool): self.create_blob_message( blob=b64decode(image.b64_json), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 0d9613c0cf..8bed2c556c 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ 
b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -71,7 +71,7 @@ class BingSearchTool(BuiltinTool): text = "" if search_results: for i, result in enumerate(search_results): - text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' + text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n' if computation and "expression" in computation and "value" in computation: text += "\nComputation:\n" diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index ac7e394911..fbd7397292 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool): self.create_blob_message( blob=b64decode(image.b64_json), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 2d62cf608f..bcfa2212b6 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool): for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) blob_message = self.create_blob_message( - blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value ) result.append(blob_message) return result diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 4cad12e4ee..c68878630d 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -83,5 +83,5 @@ class DIDApp: if status["status"] == "done": return status elif status["status"] == "error" 
or status["status"] == "rejected": - raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') + raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}') time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index dceb37db49..45ab15f437 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -1,4 +1,5 @@ import json +import urllib.parse from datetime import datetime, timedelta from typing import Any, Union @@ -13,13 +14,14 @@ class GitlabCommitsTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") employee = tool_parameters.get("employee", "") start_time = tool_parameters.get("start_time", "") end_time = tool_parameters.get("end_time", "") change_type = tool_parameters.get("change_type", "all") - if not project: - return self.create_text_message("Project is required") + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() @@ -35,91 +37,105 @@ class GitlabCommitsTool(BuiltinTool): site_url = "https://gitlab.com" # Get commit content - result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + if repository: + result = self.fetch_commits( + site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True + ) + else: + result = self.fetch_commits( + site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False + ) return 
[self.create_json_message(item) for item in result] - def fetch( + def fetch_commits( self, - user_id: str, site_url: str, access_token: str, - project: str, - employee: str = None, - start_time: str = "", - end_time: str = "", - change_type: str = "", + identifier: str, + employee: str, + start_time: str, + end_time: str, + change_type: str, + is_repository: bool, ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # Get all of projects - url = f"{domain}/api/v4/projects" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits" + else: + # Get all projects + url = f"{domain}/api/v4/projects" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() - filtered_projects = [p for p in projects if project == "*" or p["name"] == project] + filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier] - for project in filtered_projects: - project_id = project["id"] - project_name = project["name"] - print(f"Project: {project_name}") + for project in filtered_projects: + project_id = project["id"] + project_name = project["name"] + print(f"Project: {project_name}") - # Get all of project commits - commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = {"since": start_time, "until": end_time} - if employee: - params["author"] = employee + commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - commits_response = requests.get(commits_url, headers=headers, params=params) - commits_response.raise_for_status() - commits = commits_response.json() + params = {"since": start_time, "until": end_time} + if employee: + params["author"] = employee - for 
commit in commits: - commit_sha = commit["id"] - author_name = commit["author_name"] + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + for commit in commits: + commit_sha = commit["id"] + author_name = commit["author_name"] + + if is_repository: + diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff" + else: diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" - diff_response = requests.get(diff_url, headers=headers) - diff_response.raise_for_status() - diffs = diff_response.json() - for diff in diffs: - # Calculate code lines of changed - added_lines = diff["diff"].count("\n+") - removed_lines = diff["diff"].count("\n-") - total_changes = added_lines + removed_lines + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() - if change_type == "new": - if added_lines > 1: - final_code = "".join( - [ - line[1:] - for line in diff["diff"].split("\n") - if line.startswith("+") and not line.startswith("+++") - ] - ) - results.append( - {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code} - ) - else: - if total_changes > 1: - final_code = "".join( - [ - line[1:] - for line in diff["diff"].split("\n") - if (line.startswith("+") or line.startswith("-")) - and not line.startswith("+++") - and not line.startswith("---") - ] - ) - final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append( - {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} - ) + for diff in diffs: + # Calculate code lines of changes + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = "".join( + [ + line[1:] + for line in 
diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}) + else: + if total_changes > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append( + {"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped} + ) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index d38d943958..669378ac97 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -21,9 +21,20 @@ parameters: zh_Hans: 员工用户名 llm_description: User name for GitLab form: llm + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. 
+ form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目名 diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 4a42b0fd73..7606eee7af 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Any, Union import requests @@ -11,14 +12,14 @@ class GitlabFilesTool(BuiltinTool): self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") branch = tool_parameters.get("branch", "") path = tool_parameters.get("path", "") - if not project: - return self.create_text_message("Project is required") + if not project and not repository: + return self.create_text_message("Either project or repository is required") if not branch: return self.create_text_message("Branch is required") - if not path: return self.create_text_message("Path is required") @@ -30,21 +31,59 @@ class GitlabFilesTool(BuiltinTool): if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): site_url = "https://gitlab.com" - # Get project ID from project name - project_id = self.get_project_id(site_url, access_token, project) - if not project_id: - return self.create_text_message(f"Project '{project}' not found.") - - # Get commit content - result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + # Get file content + if repository: + result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) + else: + result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) return [self.create_json_message(item) for item in result] - def extract_project_name_and_path(self, 
path: str) -> tuple[str, str]: - parts = path.split("/", 1) - if len(parts) < 2: - return None, None - return parts[0], parts[1] + def fetch_files( + self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool + ) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}" + else: + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, identifier) + if not project_id: + return self.create_text_message(f"Project '{identifier}' not found.") + tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + + response = requests.get(tree_url, headers=headers) + response.raise_for_status() + items = response.json() + + for item in items: + item_path = item["path"] + if item["type"] == "tree": # It's a directory + results.extend( + self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) + ) + else: # It's a file + if is_repository: + file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}" + else: + file_url = ( + f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + ) + + file_response = requests.get(file_url, headers=headers) + file_response.raise_for_status() + file_content = file_response.text + results.append({"path": item_path, "branch": branch, "content": file_content}) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: headers = {"PRIVATE-TOKEN": access_token} @@ -59,32 +98,3 @@ class GitlabFilesTool(BuiltinTool): 
except requests.RequestException as e: print(f"Error fetching project ID from GitLab: {e}") return None - - def fetch( - self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None - ) -> list[dict[str, Any]]: - domain = site_url - headers = {"PRIVATE-TOKEN": access_token} - results = [] - - try: - # List files and directories in the given path - url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" - response = requests.get(url, headers=headers) - response.raise_for_status() - items = response.json() - - for item in items: - item_path = item["path"] - if item["type"] == "tree": # It's a directory - results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) - else: # It's a file - file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" - file_response = requests.get(file_url, headers=headers) - file_response.raise_for_status() - file_content = file_response.text - results.append({"path": item_path, "branch": branch, "content": file_content}) - except requests.RequestException as e: - print(f"Error fetching data from GitLab: {e}") - - return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml index d99b6254c1..4c733673f1 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -10,9 +10,20 @@ description: zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 llm: A tool for query GitLab files, Input should be a exists file or directory path. parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. 
+ form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目 diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 7e9f70f8e5..71f8356ab8 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -142,7 +142,7 @@ class ListWorksheetRecordsTool(BuiltinTool): for control in controls: control_type_id = self.get_real_type_id(control) if (control_type_id in self._get_ignore_types()) or ( - allow_fields and not control["controlId"] in allow_fields + allow_fields and control["controlId"] not in allow_fields ): continue else: @@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool): elif value.startswith('[{"organizeId"'): value = json.loads(value) value = "、".join([item["organizeName"] for item in value]) - elif value.startswith('[{"file_id"'): - value = "" - elif value == "[]": + elif value.startswith('[{"file_id"') or value == "[]": value = "" elif hasattr(value, "accountId"): value = value["fullname"] diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index b4193f00bf..4dba2df1f1 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -67,7 +67,7 @@ class ListWorksheetsTool(BuiltinTool): items = [] tables = "" for item in section.get("items", []): - if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"): + if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"): if type == "json": filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py 
b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index f76587bea1..0b4f2edff3 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool): self.create_blob_message( blob=b64decode(client_result.image_file), meta={"mime_type": f"image/{client_result.image_type}"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index fe105f70a7..9ca14b327c 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool): models_data=[], headers=headers, params=params, - recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True, + recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), ) result_str = "" diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 9632c163cf..9c61eab9f9 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): self.create_blob_message( blob=b64decode(image_encoded), meta={"mime_type": f"image/{image.image_type}"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 8aefc65131..d8ca20bde6 100644 --- 
a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -39,14 +39,14 @@ class QRCodeGeneratorTool(BuiltinTool): # get error_correction error_correction = tool_parameters.get("error_correction", "") - if error_correction not in self.error_correction_levels.keys(): + if error_correction not in self.error_correction_levels: return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) return self.create_blob_message( - blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception: logging.exception(f"Failed to generate QR code for content: {content}") diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index d632304a46..6d88d74635 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -44,36 +44,36 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + if "answer_box" in res and "answer" in res["answer_box"]: toret += res["answer_box"]["answer"] + "\n" - if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + if "answer_box" in res and "snippet" in res["answer_box"]: toret += res["answer_box"]["snippet"] + "\n" - if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): + if "knowledge_graph" in res and "description" in 
res["knowledge_graph"]: toret += res["knowledge_graph"]["description"] + "\n" - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys(): - if "title" in res["answer_box"]["organic_result"].keys(): + if "answer_box" in res and "organic_result" in res["answer_box"]: + if "title" in res["answer_box"]["organic_result"]: toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n" - elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys(): + elif "organic_results" in res and "link" in res["organic_results"][0]: toret = "" for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys(): + elif "related_questions" in res and "link" in res["related_questions"][0]: toret = "" for item in res["related_questions"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys(): + elif "related_searches" in res and "link" in res["related_searches"][0]: toret = "" for item in res["related_searches"]: toret += f"[{item['title']}]({item['link']})\n" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 1544061c08..d29cb0ae3f 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -44,12 +44,12 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: 
"""Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): + if "jobs" in res and "title" in res["jobs"][0]: for item in res["jobs"]: toret += ( "title: " @@ -65,7 +65,7 @@ class SearchAPI: toret = "No good search result found" elif type == "link": - if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys(): + if "jobs" in res and "apply_link" in res["jobs"][0]: for item in res["jobs"]: toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n" else: diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 95a7aad736..8458c8c958 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -44,25 +44,25 @@ class SearchAPI: @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" - if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + if "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys(): + if "organic_results" in res and "title" in 
res["organic_results"][0]: for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + elif "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += f"[{item['title']}]({item['link']})\n" else: diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 88def504fc..d7bfb53bd7 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -44,11 +44,11 @@ class SearchAPI: @staticmethod def _process_response(res: dict) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" - if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys(): + if "transcripts" in res and "text" in res["transcripts"][0]: for item in res["transcripts"]: toret += item["text"] + " " if toret == "": diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py index 5fa9926484..1b846624bd 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/flux.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -32,5 +32,5 @@ class FluxTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py index 
e7c3c28d7b..d6a0b03d1b 100644 --- a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool): res = response.json() result = [self.create_json_message(res)] for image in res.get("images", []): - result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index a6f5570af2..81d9e8d941 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class AssembleHeaderException(Exception): +class AssembleHeaderError(Exception): def __init__(self, msg): self.message = msg class Url: - def __init__(this, host, path, schema): - this.host = host - this.path = path - this.schema = schema + def __init__(self, host, path, schema): + self.host = host + self.path = path + self.schema = schema # calculate sha256 and encode to base64 @@ -41,7 +41,7 @@ def parse_url(request_url): schema = request_url[: stidx + 3] edidx = host.index("/") if edidx <= 0: - raise AssembleHeaderException("invalid request url:" + request_url) + raise AssembleHeaderError("invalid request url:" + request_url) path = host[edidx:] host = host[:edidx] u = Url(host, path, schema) @@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool): self.create_blob_message( blob=b64decode(image["base64_image"]), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) return 
result diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index c33e3bd78f..9f415ceb55 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -35,7 +35,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): if model in ["sd3", "sd3-turbo"]: payload["model"] = tool_parameters.get("model") - if not model == "sd3-turbo": + if model != "sd3-turbo": payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( @@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): raise Exception(response.text) return self.create_blob_message( - blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index c31e178067..344f916494 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -206,10 +206,9 @@ class StableDiffusionTool(BuiltinTool): # Convert image to RGB and save as PNG try: - with Image.open(io.BytesIO(image_binary)) as image: - with io.BytesIO() as buffer: - image.convert("RGB").save(buffer, format="PNG") - image_binary = buffer.getvalue() + with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() except Exception as e: return self.create_text_message(f"Failed to process the image: {str(e)}") @@ -260,7 +259,7 @@ class StableDiffusionTool(BuiltinTool): image = response.json()["images"][0] return self.create_blob_message( 
- blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception as e: @@ -294,7 +293,7 @@ class StableDiffusionTool(BuiltinTool): image = response.json()["images"][0] return self.create_blob_message( - blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) except Exception as e: diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index 93803d7937..aeaef08805 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool): ).content return self.create_blob_message( - blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value + blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index 3ba4996be1..4bd601c0bd 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool): if image_id.startswith("__test_"): image_binary = b64decode(VECTORIZER_ICON_PNG) else: - image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: return self.create_text_message("Image not found, please request user to generate image firstly.") diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 
67efcf0954..cb88e9519a 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -27,7 +27,7 @@ class WikipediaAPIWrapper: self.doc_content_chars_max = doc_content_chars_max def run(self, query: str, lang: str = "") -> str: - if lang in wikipedia.languages().keys(): + if lang in wikipedia.languages(): self.lang = lang wikipedia.set_lang(self.lang) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index ac3dc84db4..d9e9a0faad 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -63,7 +63,7 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VARIABLE_KEY(Enum): + class VariableKey(Enum): IMAGE = "image" def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": @@ -142,7 +142,7 @@ class Tool(BaseModel, ABC): if not self.variables: return None - return self.get_variable(self.VARIABLE_KEY.IMAGE) + return self.get_variable(self.VariableKey.IMAGE) def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ @@ -189,7 +189,7 @@ class Tool(BaseModel, ABC): result = [] for variable in self.variables.pool: - if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): + if variable.name.startswith(self.VariableKey.IMAGE.value): result.append(variable) return result diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index c4983ebc65..1109ed7df2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -19,9 +19,7 @@ class ToolFileMessageTransformer: result = [] for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: + if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: result.append(message) elif 
message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 0b83ee10cd..c156dd8c98 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -224,9 +224,7 @@ class Graph(BaseModel): """ leaf_node_ids = [] for node_id in self.node_ids: - if node_id not in self.edge_mapping: - leaf_node_ids.append(node_id) - elif ( + if node_id not in self.edge_mapping or ( len(self.edge_mapping[node_id]) == 1 and self.edge_mapping[node_id][0].target_node_id == self.root_node_id ): @@ -310,7 +308,7 @@ class Graph(BaseModel): parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if not condition_hash in condition_edge_mappings: + if condition_hash not in condition_edge_mappings: condition_edge_mappings[condition_hash] = [] condition_edge_mappings[condition_hash].append(graph_edge) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index c6bd122b37..1db9b690ab 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import Flask, current_app -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import ( NodeRunMetadataKey, @@ -90,9 +90,9 @@ class GraphEngine: thread_pool_max_submit_count = 100 thread_pool_max_workers = 10 - ## init thread pool + # init thread pool if thread_pool_id: - if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping: raise ValueError(f"Max submit 
count {thread_pool_max_submit_count} of workflow thread pool reached.") self.thread_pool_id = thread_pool_id @@ -669,7 +669,7 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: # trigger node run failed event route_node_state.status = RouteNodeState.Status.FAILED route_node_state.failed_reason = "Workflow stopped." diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 06050e1549..e31a1479a8 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -24,7 +24,7 @@ class AnswerStreamGeneratorRouter: # parse stream output node value selectors of answer nodes answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} for answer_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get("data", {}).get("type") == NodeType.ANSWER.value: + if node_config.get("data", {}).get("type") != NodeType.ANSWER.value: continue # get generate route for stream output diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 4a1787c8c1..a07ba2f740 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Optional, Union, cast from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from 
core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider @@ -61,7 +61,7 @@ class CodeNode(BaseNode): # Transform result result = self._transform_result(result, node_data.outputs) - except (CodeExecutionException, ValueError) as e: + except (CodeExecutionError, ValueError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 30ce8fe018..a38d982393 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -17,7 +17,7 @@ class EndStreamGeneratorRouter: # parse stream output node value selector of end nodes end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} for end_node_id, node_config in node_id_config_mapping.items(): - if not node_config.get("data", {}).get("type") == NodeType.END.value: + if node_config.get("data", {}).get("type") != NodeType.END.value: continue # skip end node in parallel diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2829144ead..32c99e0d1c 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -2,7 +2,7 @@ import os from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode 
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData @@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode): result = CodeExecutor.execute_workflow_code_template( language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables ) - except CodeExecutionException as e: + except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 28fbf789fd..9d222b10b9 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -20,7 +20,7 @@ class ToolEntity(BaseModel): if not isinstance(value, dict): raise ValueError("tool_configurations must be a dictionary") - for key in values.data.get("tool_configurations", {}).keys(): + for key in values.data.get("tool_configurations", {}): value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): raise ValueError(f"{key} must be a string") diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 25021935ee..74a598ada5 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -6,7 +6,7 @@ from typing import Any, Optional, cast from configs import dify_config from core.app.app_config.entities import FileExtraConfig -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback @@ -103,7 +103,7 @@ class WorkflowEntry: for callback in callbacks: callback.on_event(event=event) yield event - except 
GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except Exception as e: logger.exception("Unknown Error when workflow entry running") diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 72a135e73d..54f6a76e16 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import time import click from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document @@ -43,7 +43,7 @@ def handle(sender, **kwargs): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index 9ed1fcf0b4..c42f946fa8 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -5,7 +5,7 @@ from collections.abc import Generator from contextlib import closing from flask import Flask -from google.cloud import storage as GoogleCloudStorage +from google.cloud import storage as google_cloud_storage from extensions.storage.base_storage import BaseStorage @@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage): service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) - self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) + self.client 
= google_cloud_storage.Client.from_service_account_info(service_account_obj) else: - self.client = GoogleCloudStorage.Client() + self.client = google_cloud_storage.Client() def save(self, filename, data): bucket = self.client.get_bucket(self.bucket_name) diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 2d306edb40..f89902c5e8 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord from Crypto.Util.strxor import strxor -class PKCS1OAEP_Cipher: +class PKCS1OAepCipher: """Cipher object for PKCS#1 v1.5 OAEP. Do not create directly: use :func:`new` instead.""" @@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): if randfunc is None: randfunc = Random.get_random_bytes - return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) + return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc) diff --git a/api/libs/helper.py b/api/libs/helper.py index af0c2dace1..d664ef1ae7 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -84,7 +84,7 @@ def timestamp_value(timestamp): raise ValueError(error) -class str_len: +class StrLen: """Restrict input to an integer in a range (inclusive)""" def __init__(self, max_length, argument="argument"): @@ -102,7 +102,7 @@ class str_len: return value -class float_range: +class FloatRange: """Restrict input to an float in a range (inclusive)""" def __init__(self, low, high, argument="argument"): @@ -121,7 +121,7 @@ class float_range: return value -class datetime_string: +class DatetimeString: def __init__(self, format, argument="argument"): self.format = format self.argument = argument diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41d6905899..39c17534e7 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -1,6 +1,6 @@ import json -from core.llm_generator.output_parser.errors import 
OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError def parse_json_markdown(json_string: str) -> dict: @@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"Got invalid JSON object. Error: {e}") for key in expected_keys: if key not in json_obj: - raise OutputParserException( + raise OutputParserError( f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py index 0fba6a87eb..8cd4ec552b 100644 --- a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py +++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py @@ -24,6 +24,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('label') diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py index bfda7d619c..92f41f0abd 100644 --- a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -21,6 +21,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as 
batch_op: batch_op.drop_column('version') diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index 2365766837..fcca705d21 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -99,7 +99,7 @@ def upgrade(): id=id, tenant_id=tenant_id, user_id=user_id, - provider='google', + provider='google', encrypted_credentials=encrypted_credentials, created_at=created_at, updated_at=updated_at diff --git a/api/pyproject.toml b/api/pyproject.toml index dc7e271ccf..3d100ebc58 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,30 +13,27 @@ preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions + "E", # pycodestyle E rules "F", # pyflakes rules "I", # isort rules - "UP", # pyupgrade rules - "B035", # static-key-dict-comprehension - "E101", # mixed-spaces-and-tabs - "E111", # indentation-with-invalid-multiple - "E112", # no-indented-block - "E113", # unexpected-indentation - "E115", # no-indented-block-comment - "E116", # unexpected-indentation-comment - "E117", # over-indented + "N", # pep8-naming "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa "S506", # unsafe-yaml-load - "SIM116", # if-else-block-instead-of-dict-lookup - "SIM401", # if-else-block-instead-of-dict-get - "SIM910", # dict-get-with-none-default + "SIM", # flake8-simplify rules + "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence - "F601", # multi-value-repeated-key-literal - "F602", # multi-value-repeated-key-variable ] ignore = [ + "E501", # line-too-long + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "E731", # lambda-assignment "F403", # undefined-local-with-import-star "F405", # 
undefined-local-with-import-star-usage "F821", # undefined-name @@ -47,9 +44,19 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg -# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "SIM300", # yoda-conditions ] [tool.ruff.lint.per-file-ignores] @@ -65,6 +72,12 @@ ignore = [ "F401", # unused-import "F811", # redefined-while-unused ] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] [tool.ruff.format] exclude = [ diff --git a/api/services/account_service.py b/api/services/account_service.py index e1b70fc9ed..7fb42f9e81 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -32,7 +32,7 @@ from services.errors.account import ( NoPermissionError, RateLimitExceededError, RoleAlreadyAssignedError, - TenantNotFound, + TenantNotFoundError, ) from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -311,13 +311,13 @@ class TenantService: """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: - raise TenantNotFound("Tenant not found.") + raise TenantNotFoundError("Tenant not found.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: - raise TenantNotFound("Tenant not found for the account.") + raise TenantNotFoundError("Tenant not found for 
the account.") return tenant @staticmethod @@ -614,8 +614,8 @@ class RegisterService: "email": account.email, "workspace_id": tenant.id, } - expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) + expiry_hours = dify_config.INVITE_EXPIRY_HOURS + redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token @classmethod diff --git a/api/services/errors/account.py b/api/services/errors/account.py index cae31c5066..82dd9f944a 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -1,7 +1,7 @@ from services.errors.base import BaseServiceError -class AccountNotFound(BaseServiceError): +class AccountNotFoundError(BaseServiceError): pass @@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError): pass -class TenantNotFound(BaseServiceError): +class TenantNotFoundError(BaseServiceError): pass diff --git a/api/services/file_service.py b/api/services/file_service.py index 5780abb2be..bedec76334 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -56,9 +56,7 @@ class FileService: if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS ) - if extension.lower() not in allowed_extensions: - raise UnsupportedFileTypeError() - elif only_image and extension.lower() not in IMAGE_EXTENSIONS: + if extension.lower() not in allowed_extensions or only_image and extension.lower() not in IMAGE_EXTENSIONS: raise UnsupportedFileTypeError() # read file content diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1e7935d299..d8e2b1689a 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -54,7 +54,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider: + if tracing_provider not in provider_config_map and 
tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} config_class, other_keys = ( @@ -113,7 +113,7 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys(): + if tracing_provider not in provider_config_map: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 9ea4c99649..6dd755ab03 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logging.info( click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e0da5f9ed0..72c4674e0f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document from services.feature_service import 
FeatureService @@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 6e681bcf4f..cb38bc668d 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner.run([document]) end_at = time.perf_counter() logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 0a7568c385..f4c3dbd2e2 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -6,7 +6,7 @@ import click from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from 
core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 18bae14ffa..21ea11d4dd 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -5,7 +5,7 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Document @@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logging.info( click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") ) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py index b37b109eba..83317e59de 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py @@ -70,6 +70,7 @@ class MockTEIClass: }, } + @staticmethod def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: # Example 
response: # [ diff --git a/api/tests/integration_tests/tools/__mock/http.py b/api/tests/integration_tests/tools/__mock/http.py index 4dfc530010..d3c1f3101c 100644 --- a/api/tests/integration_tests/tools/__mock/http.py +++ b/api/tests/integration_tests/tools/__mock/http.py @@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 571c1e3d44..53c9b3cae3 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -13,7 +13,7 @@ from xinference_client.types import Embedding class MockTcvectordbClass: - def VectorDBClient( + def mock_vector_db_client( self, url=None, username="", @@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) + monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/http.py b/api/tests/integration_tests/workflow/nodes/__mock/http.py index cfc47bcad4..f1ab23b002 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/http.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/http.py @@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true" class MockedHttp: + @staticmethod def httpx_request( method: Literal["GET", "POST", 
"PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs ) -> httpx.Response: diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py index 44dcf9a10f..487178ff58 100644 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -1,11 +1,11 @@ import pytest -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor CODE_LANGUAGE = "unsupported_language" def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionException) as e: + with pytest.raises(CodeExecutionError) as e: CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py index 65757cd604..13ba11016a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -247,9 +247,9 @@ def test_parallels_graph(): for i in range(3): start_edges = graph.edge_mapping.get("start") assert start_edges is not None - assert start_edges[i].target_node_id == f"llm{i+1}" + assert start_edges[i].target_node_id == f"llm{i + 1}" - llm_edges = graph.edge_mapping.get(f"llm{i+1}") + llm_edges = graph.edge_mapping.get(f"llm{i + 1}") assert llm_edges is not None assert llm_edges[0].target_node_id == "answer" diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index a90f86402d..fdb09dc3ae 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ 
-159,7 +159,7 @@ const ChatInput: FC = ({ { visionConfig?.enabled && ( <> -
+
{ return ( <> -
{paragraph.children.slice(1)}
+

{paragraph.children.slice(1)}

) } - return
{paragraph.children}
+ return

{paragraph.children}

} const Img = ({ src }: any) => { diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index fd28a67e7e..dbb5cc37c7 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -41,7 +41,7 @@ const SecretKeyModal = ({ }: ISecretKeyModalProps) => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const { currentWorkspace, isCurrentWorkspaceManager } = useAppContext() + const { currentWorkspace, isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [isVisible, setVisible] = useState(false) const [newKey, setNewKey] = useState(undefined) @@ -142,7 +142,7 @@ const SecretKeyModal = ({ ) }
- diff --git a/web/app/components/workflow/hooks/use-shortcuts.ts b/web/app/components/workflow/hooks/use-shortcuts.ts index 666c3a45ba..439b521a30 100644 --- a/web/app/components/workflow/hooks/use-shortcuts.ts +++ b/web/app/components/workflow/hooks/use-shortcuts.ts @@ -70,7 +70,8 @@ export const useShortcuts = (): void => { }) useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => { - if (shouldHandleShortcut(e)) { + const { showDebugAndPreviewPanel, showInputsPanel } = workflowStore.getState() + if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel && !showInputsPanel) { e.preventDefault() handleNodesCopy() } diff --git a/web/app/components/workflow/run/meta.tsx b/web/app/components/workflow/run/meta.tsx index 9f03bd7774..22adcb0636 100644 --- a/web/app/components/workflow/run/meta.tsx +++ b/web/app/components/workflow/run/meta.tsx @@ -16,7 +16,7 @@ type Props = { const MetaData: FC = ({ status, executor, - startTime = 0, + startTime, time, tokens, steps = 1, @@ -64,7 +64,7 @@ const MetaData: FC = ({
)} {status !== 'running' && ( - {formatTime(startTime, t('appLog.dateTimeFormat') as string)} + {startTime ? formatTime(startTime, t('appLog.dateTimeFormat') as string) : '-'} )}
@@ -75,7 +75,7 @@ const MetaData: FC = ({
)} {status !== 'running' && ( - {`${time?.toFixed(3)}s`} + {time ? `${time.toFixed(3)}s` : '-'} )}
diff --git a/web/app/components/workflow/run/status.tsx b/web/app/components/workflow/run/status.tsx index 6299302209..62167c32e0 100644 --- a/web/app/components/workflow/run/status.tsx +++ b/web/app/components/workflow/run/status.tsx @@ -67,7 +67,7 @@ const StatusPanel: FC = ({
)} {status !== 'running' && ( - {`${time?.toFixed(3)}s`} + {time ? `${time?.toFixed(3)}s` : '-'} )}
diff --git a/web/app/styles/markdown.scss b/web/app/styles/markdown.scss index b0a1f60cd2..214d8d2782 100644 --- a/web/app/styles/markdown.scss +++ b/web/app/styles/markdown.scss @@ -321,18 +321,12 @@ .markdown-body h4, .markdown-body h5, .markdown-body h6 { - margin-top: 24px; - margin-bottom: 16px; + padding-top: 12px; + margin-bottom: 12px; font-weight: var(--base-text-weight-semibold, 600); line-height: 1.25; } - -.markdown-body p { - margin-top: 0; - margin-bottom: 10px; -} - .markdown-body blockquote { margin: 0; padding: 0 8px; @@ -449,7 +443,7 @@ .markdown-body pre, .markdown-body details { margin-top: 0; - margin-bottom: 16px; + margin-bottom: 12px; } .markdown-body blockquote> :first-child {