Merge branch 'main' into feat/attachments

This commit is contained in:
StyleZhang 2024-09-12 13:53:41 +08:00
commit 323a835de9
153 changed files with 654 additions and 554 deletions

View File

@@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
if not current_user.is_editor:
raise Forbidden()
current_key_count = (
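Note: this check loosens API-key management from admins and owners to anyone with edit rights. For orientation, a hedged sketch of how such role properties are commonly exposed on the account model (the role names and hierarchy here are assumptions, not shown in this diff):

class Account:
    role: str  # assumed role names: "owner", "admin", "editor", "normal"

    @property
    def is_admin_or_owner(self) -> bool:
        return self.role in ("owner", "admin")

    @property
    def is_editor(self) -> bool:
        # assumption: edit rights include owners and admins
        return self.role in ("owner", "admin", "editor")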

View File

@@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
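The datetime_string factory is replaced by a class-based DatetimeString type throughout this commit. For reference, a minimal sketch of what such a callable reqparse type might look like (the actual implementation in libs/helper.py may differ):

from datetime import datetime

class DatetimeString:
    def __init__(self, fmt: str):
        self.fmt = fmt

    def __call__(self, value: str) -> str:
        # reqparse invokes the type object with the raw argument string;
        # raising ValueError yields a 400 response carrying this message.
        try:
            datetime.strptime(value, self.fmt)
        except ValueError:
            raise ValueError(f"Invalid datetime {value!r}, expected format {self.fmt}.")
        return value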

View File

@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
@@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,

View File

@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import datetime_string
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
@@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """
@@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """

View File

@@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import email, str_len, timezone
from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
@@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import str_len
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"):

View File

@@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import email, get_remote_ip, str_len
from libs.helper import StrLen, email, get_remote_ip
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
@@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()

View File

@@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
runner.run()
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@@ -21,7 +21,7 @@ class AudioTrunk:
self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
@@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
break
@@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
self.MAX_SENTENCE += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
@@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
break
future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None:
def check_and_get_audio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:

View File

@@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
@@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query,
message_id=message_id,
)
except ModerationException as e:
except ModerationError as e:
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True

View File

@@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
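The renamed check_and_get_audio is polled in a loop that interleaves audio chunks with text responses. A simplified sketch of that polling pattern (the publisher API is reduced to the two members used above; not the exact pipeline types):

import time

def drain_audio(publisher, interval: float = 0.02):
    # Poll for ready chunks; sleep ~20 ms between misses to release the CPU.
    while True:
        chunk = publisher.check_and_get_audio()
        if chunk is None:
            time.sleep(interval)
            continue
        if chunk.status == "finish":
            return
        yield chunk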

View File

@@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
@@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@@ -171,5 +171,5 @@ class AppQueueManager:
)
class GenerateTaskStoppedException(Exception):
class GenerateTaskStoppedError(Exception):
pass
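Renaming a public exception class breaks external imports of the old name. One conventional mitigation, not part of this commit, is a temporary alias so old call sites keep working during the transition:

class GenerateTaskStoppedError(Exception):
    pass

# hypothetical compatibility shim (not in this diff):
GenerateTaskStoppedException = GenerateTaskStoppedError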

View File

@@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Conversation, Message
@@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
@@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
queue_manager=queue_manager,
message=message,
)
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(

View File

@@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db
from models.model import App, Message
@@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
query=query,
message_id=message.id,
)
except ModerationException as e:
except ModerationError as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,

View File

@@ -8,7 +8,7 @@ from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
@@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()

View File

@@ -12,7 +12,7 @@ from pydantic import ValidationError
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
@@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
)
runner.run()
except GenerateTaskStoppedException:
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()
else:
logger.exception(e)
raise e

View File

@@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
raise GenerateTaskStoppedError()

View File

@@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.checkAndGetAudio()
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@@ -15,6 +15,7 @@ class Segment(BaseModel):
value: Any
@field_validator("value_type")
@classmethod
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.

View File

@@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
elif isinstance(e, InvokeError | ValueError):
err = e
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
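Since Python 3.10, isinstance() accepts PEP 604 union types, so the chained checks collapse into one test. A quick self-contained demonstration:

class InvokeError(Exception):
    pass

e = ValueError("boom")
assert isinstance(e, InvokeError | ValueError)   # union form, Python 3.10+
assert isinstance(e, (InvokeError, ValueError))  # equivalent tuple form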

View File

@@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
@@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
audio_response = self._listen_audio_msg(publisher, task_id)
if audio_response:
yield audio_response
else:
@@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None:
break
audio = publisher.checkAndGetAudio()
audio = publisher.check_and_get_audio()
if audio is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler:
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
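The old guard tested the literal string "hit_count", which is always truthy, so the else None branch was dead code; the behavior only coincided with the intent because dict.get already returns None for a missing key. A quick demonstration:

item = {"word_count": 42}
buggy = item.get("hit_count") if "hit_count" else None          # literal is always truthy
fixed = item.get("hit_count") if "hit_count" in item else None  # membership test, as intended
assert buggy is None and fixed is None  # same result here only because .get() defaults to None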

View File

@@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
class CodeExecutionException(Exception):
class CodeExecutionError(Exception):
pass
@@ -86,15 +86,15 @@ class CodeExecutor:
),
)
if response.status_code == 503:
raise CodeExecutionException("Code execution service is unavailable")
raise CodeExecutionError("Code execution service is unavailable")
elif response.status_code != 200:
raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e
except Exception as e:
raise CodeExecutionException(
raise CodeExecutionError(
"Failed to execute code, which is likely a network issue,"
" please check if the sandbox service is running."
f" ( Error: {str(e)} )"
@@ -103,15 +103,15 @@ class CodeExecutor:
try:
response = response.json()
except:
raise CodeExecutionException("Failed to parse response")
raise CodeExecutionError("Failed to parse response")
if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
response = CodeExecutionResponse(**response)
if response.data.error:
raise CodeExecutionException(response.data.error)
raise CodeExecutionError(response.data.error)
return response.data.stdout or ""
@@ -126,13 +126,13 @@ class CodeExecutor:
"""
template_transformer = cls.code_template_transformers.get(language)
if not template_transformer:
raise CodeExecutionException(f"Unsupported language {language}")
raise CodeExecutionError(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)
try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e:
except CodeExecutionError as e:
raise e
return template_transformer.transform_response(response)

View File

@@ -78,8 +78,8 @@ class IndexingRunner:
dataset_document=dataset_document,
documents=documents,
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@@ -134,8 +134,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@@ -192,8 +192,8 @@ class IndexingRunner:
self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
except DocumentIsPausedError:
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error"
dataset_document.error = str(e.description)
@@ -756,7 +756,7 @@ class IndexingRunner:
indexing_cache_key = "document_{}_is_paused".format(document_id)
result = redis_client.get(indexing_cache_key)
if result:
raise DocumentIsPausedException()
raise DocumentIsPausedError()
@staticmethod
def _update_document_index_status(
@@ -767,10 +767,10 @@ class IndexingRunner:
"""
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
raise DocumentIsDeletedPausedError()
update_params = {DatasetDocument.indexing_status: after_indexing_status}
@@ -875,9 +875,9 @@ class IndexingRunner:
pass
class DocumentIsPausedException(Exception):
class DocumentIsPausedError(Exception):
pass
class DocumentIsDeletedPausedException(Exception):
class DocumentIsDeletedPausedError(Exception):
pass
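Pausing is signaled through a Redis flag keyed by document id. A minimal sketch of the check, with the key format taken from the hunk above and the client wiring elided:

def ensure_not_paused(redis_client, document_id: str) -> None:
    indexing_cache_key = "document_{}_is_paused".format(document_id)
    if redis_client.get(indexing_cache_key):
        raise DocumentIsPausedError("Document paused, document id: {}".format(document_id))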

View File

@@ -1,2 +1,2 @@
class OutputParserException(Exception):
class OutputParserError(Exception):
pass

View File

@@ -1,6 +1,6 @@
from typing import Any
from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
raise ValueError("Expected 'opening_statement' to be a str.")
return parsed
except Exception as e:
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")

View File

@@ -420,7 +420,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
),
)
index += 0
index += 1
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)

View File

@@ -7,7 +7,7 @@ from requests import post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@@ -45,7 +45,7 @@ class BaichuanModel:
parameters: dict[str, Any],
tools: Optional[list[PromptMessageTool]] = None,
) -> dict[str, Any]:
if model in self._model_mapping.keys():
if model in self._model_mapping:
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
# we need to rename it to res_format to get its value
if parameters.get("res_format") == "json_object":
@@ -94,7 +94,7 @@ class BaichuanModel:
timeout: int,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[Iterator, dict]:
if model in self._model_mapping.keys():
if model in self._model_mapping:
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
else:
raise BadRequestError(f"Unknown model: {model}")
@@ -124,7 +124,7 @@ class BaichuanModel:
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err == "invalid_request_error":
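On the membership checks above: in Python 3, the in operator on a dict already tests keys, so calling .keys() first is redundant. For example:

mapping = {"baichuan2-turbo": "endpoint-a"}  # hypothetical mapping
assert ("baichuan2-turbo" in mapping) == ("baichuan2-turbo" in mapping.keys())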

View File

@@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
class InsufficientAccountBalanceError(Exception):
pass

View File

@@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],

View File

@@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
@@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
if err == "invalid_api_key":
raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg)
raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg)
elif err and "rate" in err:
@@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],

View File

@@ -52,6 +52,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.00025'
output: '0.00125'

View File

@@ -52,6 +52,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.015'
output: '0.075'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -51,6 +51,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'

View File

@@ -45,6 +45,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.008'
output: '0.024'

View File

@@ -45,6 +45,8 @@ parameter_rules:
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.008'
output: '0.024'

View File

@@ -20,6 +20,7 @@ from botocore.exceptions import (
from PIL.Image import Image
# local import
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -44,6 +45,14 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
"""
class BedrockLargeLanguageModel(LargeLanguageModel):
@@ -70,6 +79,40 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
logger.info(f"current model id: {model_id} did not support by Converse API")
return None
def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if model_parameters.get("response_format"):
stop = stop or []
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
response_format = model_parameters.pop("response_format")
format_prompt = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
"{{block}}", response_format
)
)
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = format_prompt
else:
prompt_messages.insert(0, format_prompt)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _invoke(
self,
model: str,
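The wrapper above steers the model into a fenced block by registering the closing fences as stop sequences and prefilling the assistant turn. A stripped-down sketch of the same technique over plain dict messages (shapes assumed, not the runtime's entity classes):

def apply_block_mode(messages: list[dict], response_format: str, stop: list[str]) -> None:
    for fence in ("```\n", "\n```"):
        if fence not in stop:
            stop.append(fence)
    # prefilling the assistant turn forces generation to continue inside the block
    messages.append({"role": "assistant", "content": f"\n```{response_format}"})

messages, stop = [{"role": "user", "content": "Summarize as JSON."}], []
apply_block_mode(messages, "json", stop)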

View File

@@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -77,7 +77,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
inputs.append(text)
# Prepare the payload for the request
payload = {"input": inputs, "model": model, "options": {"use_mmap": "true"}}
payload = {"input": inputs, "model": model, "options": {"use_mmap": True}}
# Make the request to the Ollama API
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
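The corrected payload sends a real boolean: json.dumps serializes Python True to the JSON literal true, whereas "true" stays a quoted string, which the option presumably would not be parsed as a flag. Compare:

import json
print(json.dumps({"options": {"use_mmap": True}}))    # {"options": {"use_mmap": true}}
print(json.dumps({"options": {"use_mmap": "true"}}))  # {"options": {"use_mmap": "true"}}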

View File

@@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
)
class _CommonOAI_API_Compat:
class _CommonOaiApiCompat:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""

View File

@@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
"""
Model class for OpenAI large language model.
"""

View File

@@ -6,10 +6,10 @@ import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
"""
Model class for OpenAI Compatible Speech to text model.
"""

View File

@@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
"""
Model class for an OpenAI API-compatible text embedding model.
"""

View File

@@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
"""
Model class for an OpenAI API-compatible text embedding model.
"""

View File

@@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
break
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -115,7 +115,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
token = credentials.token
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
if "opus" or "claude-3-5-sonnet" in model:
if "opus" in model or "claude-3-5-sonnet" in model:
location = "us-east5"
else:
location = "us-central1"
@@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")

View File

@@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
class MaaSClient(MaasService):
@@ -106,7 +106,7 @@ class MaaSClient(MaasService):
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try:
resp = fn()
except MaasException as e:
except MaasError as e:
raise wrap_error(e)
return resp

View File

@@ -1,144 +1,144 @@
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
class ClientSDKRequestError(MaasException):
class ClientSDKRequestError(MaasError):
pass
class SignatureDoesNotMatch(MaasException):
class SignatureDoesNotMatchError(MaasError):
pass
class RequestTimeout(MaasException):
class RequestTimeoutError(MaasError):
pass
class ServiceConnectionTimeout(MaasException):
class ServiceConnectionTimeoutError(MaasError):
pass
class MissingAuthenticationHeader(MaasException):
class MissingAuthenticationHeaderError(MaasError):
pass
class AuthenticationHeaderIsInvalid(MaasException):
class AuthenticationHeaderIsInvalidError(MaasError):
pass
class InternalServiceError(MaasException):
class InternalServiceError(MaasError):
pass
class MissingParameter(MaasException):
class MissingParameterError(MaasError):
pass
class InvalidParameter(MaasException):
class InvalidParameterError(MaasError):
pass
class AuthenticationExpire(MaasException):
class AuthenticationExpireError(MaasError):
pass
class EndpointIsInvalid(MaasException):
class EndpointIsInvalidError(MaasError):
pass
class EndpointIsNotEnable(MaasException):
class EndpointIsNotEnableError(MaasError):
pass
class ModelNotSupportStreamMode(MaasException):
class ModelNotSupportStreamModeError(MaasError):
pass
class ReqTextExistRisk(MaasException):
class ReqTextExistRiskError(MaasError):
pass
class RespTextExistRisk(MaasException):
class RespTextExistRiskError(MaasError):
pass
class EndpointRateLimitExceeded(MaasException):
class EndpointRateLimitExceededError(MaasError):
pass
class ServiceConnectionRefused(MaasException):
class ServiceConnectionRefusedError(MaasError):
pass
class ServiceConnectionClosed(MaasException):
class ServiceConnectionClosedError(MaasError):
pass
class UnauthorizedUserForEndpoint(MaasException):
class UnauthorizedUserForEndpointError(MaasError):
pass
class InvalidEndpointWithNoURL(MaasException):
class InvalidEndpointWithNoURLError(MaasError):
pass
class EndpointAccountRpmRateLimitExceeded(MaasException):
class EndpointAccountRpmRateLimitExceededError(MaasError):
pass
class EndpointAccountTpmRateLimitExceeded(MaasException):
class EndpointAccountTpmRateLimitExceededError(MaasError):
pass
class ServiceResourceWaitQueueFull(MaasException):
class ServiceResourceWaitQueueFullError(MaasError):
pass
class EndpointIsPending(MaasException):
class EndpointIsPendingError(MaasError):
pass
class ServiceNotOpen(MaasException):
class ServiceNotOpenError(MaasError):
pass
AuthErrors = {
"SignatureDoesNotMatch": SignatureDoesNotMatch,
"MissingAuthenticationHeader": MissingAuthenticationHeader,
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid,
"AuthenticationExpire": AuthenticationExpire,
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint,
"SignatureDoesNotMatch": SignatureDoesNotMatchError,
"MissingAuthenticationHeader": MissingAuthenticationHeaderError,
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
"AuthenticationExpire": AuthenticationExpireError,
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
}
BadRequestErrors = {
"MissingParameter": MissingParameter,
"InvalidParameter": InvalidParameter,
"EndpointIsInvalid": EndpointIsInvalid,
"EndpointIsNotEnable": EndpointIsNotEnable,
"ModelNotSupportStreamMode": ModelNotSupportStreamMode,
"ReqTextExistRisk": ReqTextExistRisk,
"RespTextExistRisk": RespTextExistRisk,
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL,
"ServiceNotOpen": ServiceNotOpen,
"MissingParameter": MissingParameterError,
"InvalidParameter": InvalidParameterError,
"EndpointIsInvalid": EndpointIsInvalidError,
"EndpointIsNotEnable": EndpointIsNotEnableError,
"ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
"ReqTextExistRisk": ReqTextExistRiskError,
"RespTextExistRisk": RespTextExistRiskError,
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
"ServiceNotOpen": ServiceNotOpenError,
}
RateLimitErrors = {
"EndpointRateLimitExceeded": EndpointRateLimitExceeded,
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded,
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded,
"EndpointRateLimitExceeded": EndpointRateLimitExceededError,
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
}
ServerUnavailableErrors = {
"InternalServiceError": InternalServiceError,
"EndpointIsPending": EndpointIsPending,
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull,
"EndpointIsPending": EndpointIsPendingError,
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
}
ConnectionErrors = {
"ClientSDKRequestError": ClientSDKRequestError,
"RequestTimeout": RequestTimeout,
"ServiceConnectionTimeout": ServiceConnectionTimeout,
"ServiceConnectionRefused": ServiceConnectionRefused,
"ServiceConnectionClosed": ServiceConnectionClosed,
"RequestTimeout": RequestTimeoutError,
"ServiceConnectionTimeout": ServiceConnectionTimeoutError,
"ServiceConnectionRefused": ServiceConnectionRefusedError,
"ServiceConnectionClosed": ServiceConnectionClosedError,
}
ErrorCodeMap = {
@@ -150,7 +150,7 @@ ErrorCodeMap = {
}
def wrap_error(e: MaasException) -> Exception:
def wrap_error(e: MaasError) -> Exception:
if ErrorCodeMap.get(e.code):
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
return e

View File

@@ -1,4 +1,4 @@
from .common import ChatRole
from .maas import MaasException, MaasService
from .maas import MaasError, MaasService
__all__ = ["MaasService", "ChatRole", "MaasException"]
__all__ = ["MaasService", "ChatRole", "MaasError"]

View File

@@ -74,7 +74,7 @@ class Signer:
def sign(request, credentials):
if request.path == "":
request.path = "/"
if request.method != "GET" and not ("Content-Type" in request.headers):
if request.method != "GET" and "Content-Type" not in request.headers:
request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8"
format_date = Signer.get_current_format_date()

View File

@@ -31,7 +31,7 @@ class Service:
self.service_info.scheme = scheme
def get(self, api, params, doseq=0):
if not (api in self.api_info):
if api not in self.api_info:
raise Exception("no such api")
api_info = self.api_info[api]
@@ -49,7 +49,7 @@
raise Exception(resp.text)
def post(self, api, params, form):
if not (api in self.api_info):
if api not in self.api_info:
raise Exception("no such api")
api_info = self.api_info[api]
r = self.prepare_request(api_info, params)
@@ -72,7 +72,7 @@
raise Exception(resp.text)
def json(self, api, params, body):
if not (api in self.api_info):
if api not in self.api_info:
raise Exception("no such api")
api_info = self.api_info[api]
r = self.prepare_request(api_info, params)

View File

@@ -63,7 +63,7 @@ class MaasService(Service):
raise
if res.error is not None and res.error.code_n != 0:
raise MaasException(
raise MaasError(
res.error.code_n,
res.error.code,
res.error.message,
@@ -72,7 +72,7 @@
yield res
return iter_fn()
except MaasException:
except MaasError:
raise
except Exception as e:
raise new_client_sdk_request_error(str(e))
@@ -94,7 +94,7 @@
resp["req_id"] = req_id
return resp
except MaasException as e:
except MaasError as e:
raise e
except Exception as e:
raise new_client_sdk_request_error(str(e), req_id)
@@ -109,7 +109,7 @@
if not self._apikey and not credentials_exist:
raise new_client_sdk_request_error("no valid credential", req_id)
if not (api in self.api_info):
if api not in self.api_info:
raise new_client_sdk_request_error("no such api", req_id)
def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
@ -147,14 +147,14 @@ class MaasService(Service):
raise new_client_sdk_request_error(raw, req_id)
if resp.error:
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id)
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
else:
raise new_client_sdk_request_error(resp, req_id)
return res
class MaasException(Exception):
class MaasError(Exception):
def __init__(self, code_n, code, message, req_id):
self.code_n = code_n
self.code = code
@ -172,7 +172,7 @@ class MaasException(Exception):
def new_client_sdk_request_error(raw, req_id=""):
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
class BinaryResponseContent:
@ -192,7 +192,7 @@ class BinaryResponseContent:
if len(error_bytes) > 0:
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
def iter_bytes(self) -> Iterator[bytes]:
yield from self.response

View File

@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
MaasError,
RateLimitErrors,
ServerUnavailableErrors,
)
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
},
[UserPromptMessage(content="ping\nAnswer: ")],
)
except MaasException as e:
except MaasError as e:
raise CredentialsValidateFailedError(e.message)
@staticmethod

View File

@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors,
BadRequestErrors,
ConnectionErrors,
MaasException,
MaasError,
RateLimitErrors,
ServerUnavailableErrors,
)
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
try:
self._invoke(model=model, credentials=credentials, texts=["ping"])
except MaasException as e:
except MaasError as e:
raise CredentialsValidateFailedError(e.message)
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:

View File

@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InsufficientAccountBalanceError,
InvalidAPIKeyError,
],
InvokeBadRequestError: [BadRequestError, KeyError],
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
class InsufficientAccountBalanceError(Exception):
pass

View File

@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
"""
text = ""
for item in message:
if isinstance(item, UserPromptMessage):
text += item.content
elif isinstance(item, SystemPromptMessage):
text += item.content
elif isinstance(item, AssistantPromptMessage):
if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
text += item.content
else:
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
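The rewritten check relies on isinstance accepting a PEP 604 union on Python 3.10+, which behaves the same as passing a tuple of classes. A tiny stand-alone illustration; the class names below are placeholders, not the prompt-message classes from this codebase:

    class UserMsg: ...
    class SystemMsg: ...
    class AssistantMsg: ...

    item = SystemMsg()
    assert isinstance(item, UserMsg | SystemMsg | AssistantMsg)   # union form, Python 3.10+
    assert isinstance(item, (UserMsg, SystemMsg, AssistantMsg))   # equivalent tuple form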

View File

@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
):
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
if copy_prompt_message.role == PromptMessageRole.USER:
new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.TOOL:
if (
copy_prompt_message.role == PromptMessageRole.USER
or copy_prompt_message.role == PromptMessageRole.TOOL
):
new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
elif isinstance(message, ToolPromptMessage):
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")

View File

@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception):
class ModerationError(Exception):
pass

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationException
from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
@ -61,7 +61,7 @@ class InputModeration:
return False, inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
raise ModerationError(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDDEN:
inputs = moderation_result.inputs
query = moderation_result.query

View File

@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
for value in inputs.values():
if self._check_keywords_in_value(keywords_list, value):
return True
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
return False
def _check_keywords_in_value(self, keywords_list, value):
for keyword in keywords_list:
if keyword.lower() in value.lower():
return True
return False
def _check_keywords_in_value(self, keywords_list, value) -> bool:
return any(keyword.lower() in value.lower() for keyword in keywords_list)
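The moderation refactor is a straight loop-to-any() conversion, so behaviour should be unchanged. A small stand-alone version of the same idea, with made-up inputs:

    def contains_keyword(keywords: list[str], value: str) -> bool:
        # case-insensitive substring match, same shape as _check_keywords_in_value
        return any(keyword.lower() in value.lower() for keyword in keywords)

    def is_violated(inputs: dict, keywords: list[str]) -> bool:
        return any(contains_keyword(keywords, value) for value in inputs.values())

    assert is_violated({"query": "Buy CHEAP meds"}, ["cheap meds"])
    assert not is_violated({"query": "hello"}, ["cheap meds"])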

View File

@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
host: str = "https://api.langfuse.com"
@field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.langfuse.com"
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.smith.langchain.com"
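These hunks add @classmethod beneath the existing @field_validator decorators, which is the decorator order Pydantic v2 documents for field validators. A minimal sketch of the same pattern; TracingConfig is a hypothetical stand-in rather than one of the classes touched in this diff:

    from pydantic import BaseModel, field_validator


    class TracingConfig(BaseModel):          # hypothetical stand-in
        host: str = "https://api.langfuse.com"

        @field_validator("host")
        @classmethod
        def set_value(cls, v: str) -> str:
            # empty values fall back to the default endpoint, as in the diff
            return v or "https://api.langfuse.com"


    assert TracingConfig(host="").host == "https://api.langfuse.com"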

View File

@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
metadata: dict[str, Any]
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
if v is None:
return None

View File

@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
)
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
totalCost: Optional[float] = None
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
return validate_input_output(v, field_name)

View File

@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
@field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name
values = info.data
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
return v
return v
@classmethod
@field_validator("start_time", "end_time")
def format_time(cls, v, info: ValidationInfo):
if not isinstance(v, datetime):

View File

@ -223,7 +223,7 @@ class OpsTraceManager:
:return:
"""
# auth check
if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
if tracing_provider not in provider_config_map and tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: App = db.session.query(App).filter(App.id == app_id).first()

View File

@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
password: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config HOST is required")

View File

@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
database: str = "default"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")

View File

@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
secure: bool = False
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required")

View File

@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")

View File

@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required")

View File

@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required")

View File

@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
database: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config RELYT_HOST is required")
@ -126,27 +127,26 @@ class RelytVector(BaseVector):
)
chunks_table_data = []
with self.client.connect() as conn:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
with self.client.connect() as conn, conn.begin():
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
return ids
@ -185,11 +185,10 @@ class RelytVector(BaseVector):
)
try:
with self.client.connect() as conn:
with conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
with self.client.connect() as conn, conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e))
return False
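The change above folds the nested context managers into a single with statement (`with self.client.connect() as conn, conn.begin():`) while keeping the batching logic intact. A self-contained sketch of the same pattern with SQLAlchemy against an in-memory SQLite table; the table schema and engine URL are invented, and the 500-row batch size mirrors the diff:

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert

    engine = create_engine("sqlite:///:memory:")
    metadata = MetaData()
    items = Table("items", metadata, Column("id", Integer, primary_key=True), Column("doc", String))
    metadata.create_all(engine)

    rows = [{"id": i, "doc": f"doc-{i}"} for i in range(3)]
    batch = []
    with engine.connect() as conn, conn.begin():   # connection and transaction in one line
        for row in rows:
            batch.append(row)
            if len(batch) == 500:                  # flush a full batch
                conn.execute(insert(items).values(batch))
                batch.clear()
        if batch:                                  # flush the remainder
            conn.execute(insert(items).values(batch))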

View File

@ -63,10 +63,7 @@ class TencentVector(BaseVector):
def _has_collection(self) -> bool:
collections = self._db.list_collections()
for collection in collections:
if collection.collection_name == self._collection_name:
return True
return False
return any(collection.collection_name == self._collection_name for collection in collections)
def _create_collection(self, dimension: int) -> None:
lock_name = "vector_indexing_lock_{}".format(self._collection_name)

View File

@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
program_name: str
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required")
@ -123,20 +124,19 @@ class TiDBVector(BaseVector):
texts = [d.page_content for d in documents]
chunks_table_data = []
with self._engine.connect() as conn:
with conn.begin():
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
with self._engine.connect() as conn, conn.begin():
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == 500:
conn.execute(insert(table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(table).values(chunks_table_data))
return ids
def text_exists(self, id: str) -> bool:
@ -159,11 +159,10 @@ class TiDBVector(BaseVector):
raise ValueError("No ids provided to delete.")
table = self._table(self._dimension)
try:
with self._engine.connect() as conn:
with conn.begin():
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
with self._engine.connect() as conn, conn.begin():
delete_condition = table.c.id.in_(ids)
conn.execute(table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e))
return False

View File

@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
batch_size: int = 100
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required")

View File

@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
# TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):

View File

@ -120,8 +120,8 @@ class WeightRerankRunner:
intersection = set(vec1.keys()) & set(vec2.keys())
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
sum1 = sum(vec1[x] ** 2 for x in vec1)
sum2 = sum(vec2[x] ** 2 for x in vec2)
denominator = math.sqrt(sum1) * math.sqrt(sum2)
if not denominator:

View File

@ -581,8 +581,8 @@ class DatasetRetrieval:
intersection = set(vec1.keys()) & set(vec2.keys())
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
sum1 = sum(vec1[x] ** 2 for x in vec1)
sum2 = sum(vec2[x] ** 2 for x in vec2)
denominator = math.sqrt(sum1) * math.sqrt(sum2)
if not denominator:

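Both reranking hunks drop the redundant .keys() call; iterating a dict already yields its keys, so the sums are unchanged. A stand-alone version of the cosine similarity being computed, with toy term-weight vectors:

    import math


    def cosine(vec1: dict[str, float], vec2: dict[str, float]) -> float:
        intersection = set(vec1) & set(vec2)
        numerator = sum(vec1[x] * vec2[x] for x in intersection)
        denominator = math.sqrt(sum(v ** 2 for v in vec1.values())) * math.sqrt(sum(v ** 2 for v in vec2.values()))
        return numerator / denominator if denominator else 0.0


    print(cosine({"a": 1.0, "b": 2.0}, {"b": 2.0, "c": 1.0}))  # 0.8 for this toy pair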
View File

@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))

View File

@ -71,7 +71,7 @@ class BingSearchTool(BuiltinTool):
text = ""
if search_results:
for i, result in enumerate(search_results):
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
if computation and "expression" in computation and "value" in computation:
text += "\nComputation:\n"

View File

@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
)
result.append(blob_message)
return result

View File

@ -83,5 +83,5 @@ class DIDApp:
if status["status"] == "done":
return status
elif status["status"] == "error" or status["status"] == "rejected":
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
time.sleep(poll_interval)

View File

@ -1,4 +1,5 @@
import json
import urllib.parse
from datetime import datetime, timedelta
from typing import Any, Union
@ -13,13 +14,14 @@ class GitlabCommitsTool(BuiltinTool):
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
employee = tool_parameters.get("employee", "")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
change_type = tool_parameters.get("change_type", "all")
if not project:
return self.create_text_message("Project is required")
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
@ -35,91 +37,105 @@ class GitlabCommitsTool(BuiltinTool):
site_url = "https://gitlab.com"
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
if repository:
result = self.fetch_commits(
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
)
else:
result = self.fetch_commits(
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
)
return [self.create_json_message(item) for item in result]
def fetch(
def fetch_commits(
self,
user_id: str,
site_url: str,
access_token: str,
project: str,
employee: str = None,
start_time: str = "",
end_time: str = "",
change_type: str = "",
identifier: str,
employee: str,
start_time: str,
end_time: str,
change_type: str,
is_repository: bool,
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# Get all of projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
else:
# Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
filtered_projects = [p for p in projects if project == "*" or p["name"] == project]
filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
for project in filtered_projects:
project_id = project["id"]
project_name = project["name"]
print(f"Project: {project_name}")
for project in filtered_projects:
project_id = project["id"]
project_name = project["name"]
print(f"Project: {project_name}")
# Get all of project commits
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {"since": start_time, "until": end_time}
if employee:
params["author"] = employee
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
params = {"since": start_time, "until": end_time}
if employee:
params["author"] = employee
for commit in commits:
commit_sha = commit["id"]
author_name = commit["author_name"]
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
for commit in commits:
commit_sha = commit["id"]
author_name = commit["author_name"]
if is_repository:
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
else:
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
# Calculate code lines of changed
added_lines = diff["diff"].count("\n+")
removed_lines = diff["diff"].count("\n-")
total_changes = added_lines + removed_lines
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
if change_type == "new":
if added_lines > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if line.startswith("+") and not line.startswith("+++")
]
)
results.append(
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}
)
else:
if total_changes > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if (line.startswith("+") or line.startswith("-"))
and not line.startswith("+++")
and not line.startswith("---")
]
)
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append(
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
)
for diff in diffs:
# Calculate code lines of changes
added_lines = diff["diff"].count("\n+")
removed_lines = diff["diff"].count("\n-")
total_changes = added_lines + removed_lines
if change_type == "new":
if added_lines > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if line.startswith("+") and not line.startswith("+++")
]
)
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
else:
if total_changes > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if (line.startswith("+") or line.startswith("-"))
and not line.startswith("+++")
and not line.startswith("---")
]
)
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append(
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
)
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
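The repository branch above percent-encodes the `namespace/project_name` path so it can be used as a single project identifier in GitLab API URLs; with safe="" the slash is encoded as well. A quick illustration, with an invented repository name:

    import urllib.parse

    repository = "my-group/my-project"                  # illustrative value
    encoded = urllib.parse.quote(repository, safe="")   # -> "my-group%2Fmy-project"
    print(f"https://gitlab.com/api/v4/projects/{encoded}/repository/commits")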

View File

@ -21,9 +21,20 @@ parameters:
zh_Hans: 员工用户名
llm_description: User name for GitLab
form: llm
- name: repository
type: string
required: false
label:
en_US: repository
zh_Hans: 仓库路径
human_description:
en_US: repository
zh_Hans: 仓库路径以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: project
type: string
required: true
required: false
label:
en_US: project
zh_Hans: 项目名

View File

@ -1,3 +1,4 @@
import urllib.parse
from typing import Any, Union
import requests
@ -11,14 +12,14 @@ class GitlabFilesTool(BuiltinTool):
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
branch = tool_parameters.get("branch", "")
path = tool_parameters.get("path", "")
if not project:
return self.create_text_message("Project is required")
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not branch:
return self.create_text_message("Branch is required")
if not path:
return self.create_text_message("Path is required")
@ -30,21 +31,59 @@ class GitlabFilesTool(BuiltinTool):
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com"
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, project)
if not project_id:
return self.create_text_message(f"Project '{project}' not found.")
# Get commit content
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
# Get file content
if repository:
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
else:
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
return [self.create_json_message(item) for item in result]
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
parts = path.split("/", 1)
if len(parts) < 2:
return None, None
return parts[0], parts[1]
def fetch_files(
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
else:
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, identifier)
if not project_id:
return self.create_text_message(f"Project '{identifier}' not found.")
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(tree_url, headers=headers)
response.raise_for_status()
items = response.json()
for item in items:
item_path = item["path"]
if item["type"] == "tree": # It's a directory
results.extend(
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
)
else: # It's a file
if is_repository:
file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}"
else:
file_url = (
f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
)
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
headers = {"PRIVATE-TOKEN": access_token}
@ -59,32 +98,3 @@ class GitlabFilesTool(BuiltinTool):
except requests.RequestException as e:
print(f"Error fetching project ID from GitLab: {e}")
return None
def fetch(
self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# List files and directories in the given path
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(url, headers=headers)
response.raise_for_status()
items = response.json()
for item in items:
item_path = item["path"]
if item["type"] == "tree": # It's a directory
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
else: # It's a file
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results

View File

@ -10,9 +10,20 @@ description:
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
llm: A tool for querying GitLab files. Input should be an existing file or directory path.
parameters:
- name: repository
type: string
required: false
label:
en_US: repository
zh_Hans: 仓库路径
human_description:
en_US: repository
zh_Hans: 仓库路径以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: project
type: string
required: true
required: false
label:
en_US: project
zh_Hans: 项目

View File

@ -142,7 +142,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
for control in controls:
control_type_id = self.get_real_type_id(control)
if (control_type_id in self._get_ignore_types()) or (
allow_fields and not control["controlId"] in allow_fields
allow_fields and control["controlId"] not in allow_fields
):
continue
else:
@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
elif value.startswith('[{"organizeId"'):
value = json.loads(value)
value = "".join([item["organizeName"] for item in value])
elif value.startswith('[{"file_id"'):
value = ""
elif value == "[]":
elif value.startswith('[{"file_id"') or value == "[]":
value = ""
elif hasattr(value, "accountId"):
value = value["fullname"]

View File

@ -67,7 +67,7 @@ class ListWorksheetsTool(BuiltinTool):
items = []
tables = ""
for item in section.get("items", []):
if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"):
if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"):
if type == "json":
filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")}
items.append(filtered_item)

View File

@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
self.create_blob_message(
blob=b64decode(client_result.image_file),
meta={"mime_type": f"image/{client_result.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
models_data=[],
headers=headers,
params=params,
recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
)
result_str = ""

View File

@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
self.create_blob_message(
blob=b64decode(image_encoded),
meta={"mime_type": f"image/{image.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value,
save_as=self.VariableKey.IMAGE.value,
)
)

View File

@ -39,14 +39,14 @@ class QRCodeGeneratorTool(BuiltinTool):
# get error_correction
error_correction = tool_parameters.get("error_correction", "")
if error_correction not in self.error_correction_levels.keys():
if error_correction not in self.error_correction_levels:
return self.create_text_message("Invalid parameter error_correction")
try:
image = self._generate_qrcode(content, border, error_correction)
image_bytes = self._image_to_byte_array(image)
return self.create_blob_message(
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)
except Exception:
logging.exception(f"Failed to generate QR code for content: {content}")

Some files were not shown because too many files have changed in this diff