mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 02:45:57 +08:00
Merge branch 'main' into feat/attachments
This commit is contained in:
commit
323a835de9
@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
|
||||
def post(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
@ -20,7 +20,7 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||
|
||||
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
|
@ -8,7 +8,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, timezone
|
||||
from libs.helper import StrLen, email, timezone
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
@ -37,7 +37,7 @@ class ActivateApi(Resource):
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
|
@ -4,7 +4,7 @@ from flask import session
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import str_len
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
|
@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import email, get_remote_ip, str_len
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
@ -40,7 +40,7 @@ class SetupApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -21,7 +21,7 @@ class AudioTrunk:
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(
|
||||
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
def check_and_get_audio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
|
@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
|
||||
|
@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -171,5 +171,5 @@ class AppQueueManager:
|
||||
)
|
||||
|
||||
|
||||
class GenerateTaskStoppedException(Exception):
|
||||
class GenerateTaskStoppedError(Exception):
|
||||
pass
|
||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -8,7 +8,7 @@ from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -15,6 +15,7 @@ class Segment(BaseModel):
|
||||
value: Any
|
||||
|
||||
@field_validator("value_type")
|
||||
@classmethod
|
||||
def validate_value_type(cls, value):
|
||||
"""
|
||||
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
||||
|
@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
elif isinstance(e, InvokeError | ValueError):
|
||||
err = e
|
||||
else:
|
||||
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
||||
|
@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if publisher is None:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
||||
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
if publisher is None:
|
||||
break
|
||||
audio = publisher.checkAndGetAudio()
|
||||
audio = publisher.check_and_get_audio()
|
||||
if audio is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler:
|
||||
data_source_type=item.get("data_source_type"),
|
||||
segment_id=item.get("segment_id"),
|
||||
score=item.get("score") if "score" in item else None,
|
||||
hit_count=item.get("hit_count") if "hit_count" else None,
|
||||
hit_count=item.get("hit_count") if "hit_count" in item else None,
|
||||
word_count=item.get("word_count") if "word_count" in item else None,
|
||||
segment_position=item.get("segment_position") if "segment_position" in item else None,
|
||||
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
|
||||
|
@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeExecutionException(Exception):
|
||||
class CodeExecutionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ -86,15 +86,15 @@ class CodeExecutor:
|
||||
),
|
||||
)
|
||||
if response.status_code == 503:
|
||||
raise CodeExecutionException("Code execution service is unavailable")
|
||||
raise CodeExecutionError("Code execution service is unavailable")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
|
||||
)
|
||||
except CodeExecutionException as e:
|
||||
except CodeExecutionError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise CodeExecutionException(
|
||||
raise CodeExecutionError(
|
||||
"Failed to execute code, which is likely a network issue,"
|
||||
" please check if the sandbox service is running."
|
||||
f" ( Error: {str(e)} )"
|
||||
@ -103,15 +103,15 @@ class CodeExecutor:
|
||||
try:
|
||||
response = response.json()
|
||||
except:
|
||||
raise CodeExecutionException("Failed to parse response")
|
||||
raise CodeExecutionError("Failed to parse response")
|
||||
|
||||
if (code := response.get("code")) != 0:
|
||||
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||
|
||||
response = CodeExecutionResponse(**response)
|
||||
|
||||
if response.data.error:
|
||||
raise CodeExecutionException(response.data.error)
|
||||
raise CodeExecutionError(response.data.error)
|
||||
|
||||
return response.data.stdout or ""
|
||||
|
||||
@ -126,13 +126,13 @@ class CodeExecutor:
|
||||
"""
|
||||
template_transformer = cls.code_template_transformers.get(language)
|
||||
if not template_transformer:
|
||||
raise CodeExecutionException(f"Unsupported language {language}")
|
||||
raise CodeExecutionError(f"Unsupported language {language}")
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
except CodeExecutionException as e:
|
||||
except CodeExecutionError as e:
|
||||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
|
@ -78,8 +78,8 @@ class IndexingRunner:
|
||||
dataset_document=dataset_document,
|
||||
documents=documents,
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -134,8 +134,8 @@ class IndexingRunner:
|
||||
self._load(
|
||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -192,8 +192,8 @@ class IndexingRunner:
|
||||
self._load(
|
||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -756,7 +756,7 @@ class IndexingRunner:
|
||||
indexing_cache_key = "document_{}_is_paused".format(document_id)
|
||||
result = redis_client.get(indexing_cache_key)
|
||||
if result:
|
||||
raise DocumentIsPausedException()
|
||||
raise DocumentIsPausedError()
|
||||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
@ -767,10 +767,10 @@ class IndexingRunner:
|
||||
"""
|
||||
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||
if count > 0:
|
||||
raise DocumentIsPausedException()
|
||||
raise DocumentIsPausedError()
|
||||
document = DatasetDocument.query.filter_by(id=document_id).first()
|
||||
if not document:
|
||||
raise DocumentIsDeletedPausedException()
|
||||
raise DocumentIsDeletedPausedError()
|
||||
|
||||
update_params = {DatasetDocument.indexing_status: after_indexing_status}
|
||||
|
||||
@ -875,9 +875,9 @@ class IndexingRunner:
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
class DocumentIsPausedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsDeletedPausedException(Exception):
|
||||
class DocumentIsDeletedPausedError(Exception):
|
||||
pass
|
||||
|
@ -1,2 +1,2 @@
|
||||
class OutputParserException(Exception):
|
||||
class OutputParserError(Exception):
|
||||
pass
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserException
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import (
|
||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
|
||||
raise ValueError("Expected 'opening_statement' to be a str.")
|
||||
return parsed
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
||||
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
||||
|
@ -420,7 +420,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
),
|
||||
)
|
||||
|
||||
index += 0
|
||||
index += 1
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||
|
@ -7,7 +7,7 @@ from requests import post
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -45,7 +45,7 @@ class BaichuanModel:
|
||||
parameters: dict[str, Any],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> dict[str, Any]:
|
||||
if model in self._model_mapping.keys():
|
||||
if model in self._model_mapping:
|
||||
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
||||
# we need to rename it to res_format to get its value
|
||||
if parameters.get("res_format") == "json_object":
|
||||
@ -94,7 +94,7 @@ class BaichuanModel:
|
||||
timeout: int,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Union[Iterator, dict]:
|
||||
if model in self._model_mapping.keys():
|
||||
if model in self._model_mapping:
|
||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
@ -124,7 +124,7 @@ class BaichuanModel:
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err == "invalid_request_error":
|
||||
|
@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
class InsufficientAccountBalanceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and "rate" in err:
|
||||
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00125'
|
||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.075'
|
||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.003'
|
||||
output: '0.015'
|
||||
|
@ -45,6 +45,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
|
@ -45,6 +45,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
|
@ -20,6 +20,7 @@ from botocore.exceptions import (
|
||||
from PIL.Image import Image
|
||||
|
||||
# local import
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -44,6 +45,14 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
"""
|
||||
|
||||
|
||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
@ -70,6 +79,40 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
logger.info(f"current model id: {model_id} did not support by Converse API")
|
||||
return None
|
||||
|
||||
def _code_block_mode_wrapper(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: list[Callback] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Code block mode wrapper for invoking large language model
|
||||
"""
|
||||
if model_parameters.get("response_format"):
|
||||
stop = stop or []
|
||||
if "```\n" not in stop:
|
||||
stop.append("```\n")
|
||||
if "\n```" not in stop:
|
||||
stop.append("\n```")
|
||||
response_format = model_parameters.pop("response_format")
|
||||
format_prompt = SystemPromptMessage(
|
||||
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
|
||||
"{{block}}", response_format
|
||||
)
|
||||
)
|
||||
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = format_prompt
|
||||
else:
|
||||
prompt_messages.insert(0, format_prompt)
|
||||
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
|
||||
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -77,7 +77,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
||||
inputs.append(text)
|
||||
|
||||
# Prepare the payload for the request
|
||||
payload = {"input": inputs, "model": model, "options": {"use_mmap": "true"}}
|
||||
payload = {"input": inputs, "model": model, "options": {"use_mmap": True}}
|
||||
|
||||
# Make the request to the Ollama API
|
||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||
|
@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
|
||||
|
||||
class _CommonOAI_API_Compat:
|
||||
class _CommonOaiApiCompat:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
|
@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
"""
|
||||
Model class for OpenAI large language model.
|
||||
"""
|
||||
|
@ -6,10 +6,10 @@ import requests
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
|
||||
class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
|
||||
"""
|
||||
Model class for OpenAI Compatible Speech to text model.
|
||||
"""
|
||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
break
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -115,7 +115,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
token = credentials.token
|
||||
|
||||
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
|
||||
if "opus" or "claude-3-5-sonnet" in model:
|
||||
if "opus" in model or "claude-3-5-sonnet" in model:
|
||||
location = "us-east5"
|
||||
else:
|
||||
location = "us-central1"
|
||||
@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
|
||||
|
||||
|
||||
class MaaSClient(MaasService):
|
||||
@ -106,7 +106,7 @@ class MaaSClient(MaasService):
|
||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||
try:
|
||||
resp = fn()
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise wrap_error(e)
|
||||
|
||||
return resp
|
||||
|
@ -1,144 +1,144 @@
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
|
||||
|
||||
|
||||
class ClientSDKRequestError(MaasException):
|
||||
class ClientSDKRequestError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class SignatureDoesNotMatch(MaasException):
|
||||
class SignatureDoesNotMatchError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class RequestTimeout(MaasException):
|
||||
class RequestTimeoutError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionTimeout(MaasException):
|
||||
class ServiceConnectionTimeoutError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingAuthenticationHeader(MaasException):
|
||||
class MissingAuthenticationHeaderError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationHeaderIsInvalid(MaasException):
|
||||
class AuthenticationHeaderIsInvalidError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InternalServiceError(MaasException):
|
||||
class InternalServiceError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingParameter(MaasException):
|
||||
class MissingParameterError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidParameter(MaasException):
|
||||
class InvalidParameterError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationExpire(MaasException):
|
||||
class AuthenticationExpireError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsInvalid(MaasException):
|
||||
class EndpointIsInvalidError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsNotEnable(MaasException):
|
||||
class EndpointIsNotEnableError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotSupportStreamMode(MaasException):
|
||||
class ModelNotSupportStreamModeError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ReqTextExistRisk(MaasException):
|
||||
class ReqTextExistRiskError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class RespTextExistRisk(MaasException):
|
||||
class RespTextExistRiskError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointRateLimitExceeded(MaasException):
|
||||
class EndpointRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionRefused(MaasException):
|
||||
class ServiceConnectionRefusedError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionClosed(MaasException):
|
||||
class ServiceConnectionClosedError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedUserForEndpoint(MaasException):
|
||||
class UnauthorizedUserForEndpointError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEndpointWithNoURL(MaasException):
|
||||
class InvalidEndpointWithNoURLError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointAccountRpmRateLimitExceeded(MaasException):
|
||||
class EndpointAccountRpmRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointAccountTpmRateLimitExceeded(MaasException):
|
||||
class EndpointAccountTpmRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceResourceWaitQueueFull(MaasException):
|
||||
class ServiceResourceWaitQueueFullError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsPending(MaasException):
|
||||
class EndpointIsPendingError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceNotOpen(MaasException):
|
||||
class ServiceNotOpenError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
AuthErrors = {
|
||||
"SignatureDoesNotMatch": SignatureDoesNotMatch,
|
||||
"MissingAuthenticationHeader": MissingAuthenticationHeader,
|
||||
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid,
|
||||
"AuthenticationExpire": AuthenticationExpire,
|
||||
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint,
|
||||
"SignatureDoesNotMatch": SignatureDoesNotMatchError,
|
||||
"MissingAuthenticationHeader": MissingAuthenticationHeaderError,
|
||||
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
|
||||
"AuthenticationExpire": AuthenticationExpireError,
|
||||
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
|
||||
}
|
||||
|
||||
BadRequestErrors = {
|
||||
"MissingParameter": MissingParameter,
|
||||
"InvalidParameter": InvalidParameter,
|
||||
"EndpointIsInvalid": EndpointIsInvalid,
|
||||
"EndpointIsNotEnable": EndpointIsNotEnable,
|
||||
"ModelNotSupportStreamMode": ModelNotSupportStreamMode,
|
||||
"ReqTextExistRisk": ReqTextExistRisk,
|
||||
"RespTextExistRisk": RespTextExistRisk,
|
||||
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL,
|
||||
"ServiceNotOpen": ServiceNotOpen,
|
||||
"MissingParameter": MissingParameterError,
|
||||
"InvalidParameter": InvalidParameterError,
|
||||
"EndpointIsInvalid": EndpointIsInvalidError,
|
||||
"EndpointIsNotEnable": EndpointIsNotEnableError,
|
||||
"ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
|
||||
"ReqTextExistRisk": ReqTextExistRiskError,
|
||||
"RespTextExistRisk": RespTextExistRiskError,
|
||||
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
|
||||
"ServiceNotOpen": ServiceNotOpenError,
|
||||
}
|
||||
|
||||
RateLimitErrors = {
|
||||
"EndpointRateLimitExceeded": EndpointRateLimitExceeded,
|
||||
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded,
|
||||
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded,
|
||||
"EndpointRateLimitExceeded": EndpointRateLimitExceededError,
|
||||
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
|
||||
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
|
||||
}
|
||||
|
||||
ServerUnavailableErrors = {
|
||||
"InternalServiceError": InternalServiceError,
|
||||
"EndpointIsPending": EndpointIsPending,
|
||||
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull,
|
||||
"EndpointIsPending": EndpointIsPendingError,
|
||||
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
|
||||
}
|
||||
|
||||
ConnectionErrors = {
|
||||
"ClientSDKRequestError": ClientSDKRequestError,
|
||||
"RequestTimeout": RequestTimeout,
|
||||
"ServiceConnectionTimeout": ServiceConnectionTimeout,
|
||||
"ServiceConnectionRefused": ServiceConnectionRefused,
|
||||
"ServiceConnectionClosed": ServiceConnectionClosed,
|
||||
"RequestTimeout": RequestTimeoutError,
|
||||
"ServiceConnectionTimeout": ServiceConnectionTimeoutError,
|
||||
"ServiceConnectionRefused": ServiceConnectionRefusedError,
|
||||
"ServiceConnectionClosed": ServiceConnectionClosedError,
|
||||
}
|
||||
|
||||
ErrorCodeMap = {
|
||||
@ -150,7 +150,7 @@ ErrorCodeMap = {
|
||||
}
|
||||
|
||||
|
||||
def wrap_error(e: MaasException) -> Exception:
|
||||
def wrap_error(e: MaasError) -> Exception:
|
||||
if ErrorCodeMap.get(e.code):
|
||||
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
|
||||
return e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .common import ChatRole
|
||||
from .maas import MaasException, MaasService
|
||||
from .maas import MaasError, MaasService
|
||||
|
||||
__all__ = ["MaasService", "ChatRole", "MaasException"]
|
||||
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||
|
@ -74,7 +74,7 @@ class Signer:
|
||||
def sign(request, credentials):
|
||||
if request.path == "":
|
||||
request.path = "/"
|
||||
if request.method != "GET" and not ("Content-Type" in request.headers):
|
||||
if request.method != "GET" and "Content-Type" not in request.headers:
|
||||
request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8"
|
||||
|
||||
format_date = Signer.get_current_format_date()
|
||||
|
@ -31,7 +31,7 @@ class Service:
|
||||
self.service_info.scheme = scheme
|
||||
|
||||
def get(self, api, params, doseq=0):
|
||||
if not (api in self.api_info):
|
||||
if api not in self.api_info:
|
||||
raise Exception("no such api")
|
||||
api_info = self.api_info[api]
|
||||
|
||||
@ -49,7 +49,7 @@ class Service:
|
||||
raise Exception(resp.text)
|
||||
|
||||
def post(self, api, params, form):
|
||||
if not (api in self.api_info):
|
||||
if api not in self.api_info:
|
||||
raise Exception("no such api")
|
||||
api_info = self.api_info[api]
|
||||
r = self.prepare_request(api_info, params)
|
||||
@ -72,7 +72,7 @@ class Service:
|
||||
raise Exception(resp.text)
|
||||
|
||||
def json(self, api, params, body):
|
||||
if not (api in self.api_info):
|
||||
if api not in self.api_info:
|
||||
raise Exception("no such api")
|
||||
api_info = self.api_info[api]
|
||||
r = self.prepare_request(api_info, params)
|
||||
|
@ -63,7 +63,7 @@ class MaasService(Service):
|
||||
raise
|
||||
|
||||
if res.error is not None and res.error.code_n != 0:
|
||||
raise MaasException(
|
||||
raise MaasError(
|
||||
res.error.code_n,
|
||||
res.error.code,
|
||||
res.error.message,
|
||||
@ -72,7 +72,7 @@ class MaasService(Service):
|
||||
yield res
|
||||
|
||||
return iter_fn()
|
||||
except MaasException:
|
||||
except MaasError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise new_client_sdk_request_error(str(e))
|
||||
@ -94,7 +94,7 @@ class MaasService(Service):
|
||||
resp["req_id"] = req_id
|
||||
return resp
|
||||
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise new_client_sdk_request_error(str(e), req_id)
|
||||
@ -109,7 +109,7 @@ class MaasService(Service):
|
||||
if not self._apikey and not credentials_exist:
|
||||
raise new_client_sdk_request_error("no valid credential", req_id)
|
||||
|
||||
if not (api in self.api_info):
|
||||
if api not in self.api_info:
|
||||
raise new_client_sdk_request_error("no such api", req_id)
|
||||
|
||||
def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
|
||||
@ -147,14 +147,14 @@ class MaasService(Service):
|
||||
raise new_client_sdk_request_error(raw, req_id)
|
||||
|
||||
if resp.error:
|
||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
||||
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
||||
else:
|
||||
raise new_client_sdk_request_error(resp, req_id)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class MaasException(Exception):
|
||||
class MaasError(Exception):
|
||||
def __init__(self, code_n, code, message, req_id):
|
||||
self.code_n = code_n
|
||||
self.code = code
|
||||
@ -172,7 +172,7 @@ class MaasException(Exception):
|
||||
|
||||
|
||||
def new_client_sdk_request_error(raw, req_id=""):
|
||||
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
||||
return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
||||
|
||||
|
||||
class BinaryResponseContent:
|
||||
@ -192,7 +192,7 @@ class BinaryResponseContent:
|
||||
|
||||
if len(error_bytes) > 0:
|
||||
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
|
||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
||||
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
||||
|
||||
def iter_bytes(self) -> Iterator[bytes]:
|
||||
yield from self.response
|
||||
|
@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
MaasError,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
},
|
||||
[UserPromptMessage(content="ping\nAnswer: ")],
|
||||
)
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
@staticmethod
|
||||
|
@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
MaasError,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||
|
@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
class InsufficientAccountBalanceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
text = ""
|
||||
for item in message:
|
||||
if isinstance(item, UserPromptMessage):
|
||||
text += item.content
|
||||
elif isinstance(item, SystemPromptMessage):
|
||||
text += item.content
|
||||
elif isinstance(item, AssistantPromptMessage):
|
||||
if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
|
||||
text += item.content
|
||||
else:
|
||||
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
||||
|
@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
):
|
||||
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
||||
else:
|
||||
if copy_prompt_message.role == PromptMessageRole.USER:
|
||||
new_prompt_messages.append(copy_prompt_message)
|
||||
elif copy_prompt_message.role == PromptMessageRole.TOOL:
|
||||
if (
|
||||
copy_prompt_message.role == PromptMessageRole.USER
|
||||
or copy_prompt_message.role == PromptMessageRole.TOOL
|
||||
):
|
||||
new_prompt_messages.append(copy_prompt_message)
|
||||
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
||||
@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message_text = f"{ai_prompt} {content}"
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message_text = content
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||
message_text = content
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
|
||||
# inputs_config
|
||||
inputs_config = config.get("inputs_config")
|
||||
if not isinstance(inputs_config, dict):
|
||||
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
|
||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
class ModerationException(Exception):
|
||||
class ModerationError(Exception):
|
||||
pass
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationException
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
@ -61,7 +61,7 @@ class InputModeration:
|
||||
return False, inputs, query
|
||||
|
||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
raise ModerationException(moderation_result.preset_response)
|
||||
raise ModerationError(moderation_result.preset_response)
|
||||
elif moderation_result.action == ModerationAction.OVERRIDDEN:
|
||||
inputs = moderation_result.inputs
|
||||
query = moderation_result.query
|
||||
|
@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
|
||||
)
|
||||
|
||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||
for value in inputs.values():
|
||||
if self._check_keywords_in_value(keywords_list, value):
|
||||
return True
|
||||
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||
|
||||
return False
|
||||
|
||||
def _check_keywords_in_value(self, keywords_list, value):
|
||||
for keyword in keywords_list:
|
||||
if keyword.lower() in value.lower():
|
||||
return True
|
||||
return False
|
||||
def _check_keywords_in_value(self, keywords_list, value) -> bool:
|
||||
return any(keyword.lower() in value.lower() for keyword in keywords_list)
|
||||
|
@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.langfuse.com"
|
||||
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.smith.langchain.com"
|
||||
|
@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
|
||||
metadata: dict[str, Any]
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
|
@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
|
||||
totalCost: Optional[float] = None
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
|
@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("start_time", "end_time")
|
||||
def format_time(cls, v, info: ValidationInfo):
|
||||
if not isinstance(v, datetime):
|
||||
|
@ -223,7 +223,7 @@ class OpsTraceManager:
|
||||
:return:
|
||||
"""
|
||||
# auth check
|
||||
if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
|
||||
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
||||
|
@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
|
||||
password: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
|
@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
|
||||
database: str = "default"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get("uri"):
|
||||
raise ValueError("config MILVUS_URI is required")
|
||||
|
@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
|
||||
secure: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get("host"):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
|
@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config ORACLE_HOST is required")
|
||||
|
@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||
|
@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTOR_HOST is required")
|
||||
|
@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
@ -126,27 +127,26 @@ class RelytVector(BaseVector):
|
||||
)
|
||||
|
||||
chunks_table_data = []
|
||||
with self.client.connect() as conn:
|
||||
with conn.begin():
|
||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||
chunks_table_data.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
"embedding": embedding,
|
||||
"document": document,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
with self.client.connect() as conn, conn.begin():
|
||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||
chunks_table_data.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
"embedding": embedding,
|
||||
"document": document,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
# Execute the batch insert when the batch size is reached
|
||||
if len(chunks_table_data) == 500:
|
||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||
# Clear the chunks_table_data list for the next batch
|
||||
chunks_table_data.clear()
|
||||
|
||||
# Insert any remaining records that didn't make up a full batch
|
||||
if chunks_table_data:
|
||||
# Execute the batch insert when the batch size is reached
|
||||
if len(chunks_table_data) == 500:
|
||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||
# Clear the chunks_table_data list for the next batch
|
||||
chunks_table_data.clear()
|
||||
|
||||
# Insert any remaining records that didn't make up a full batch
|
||||
if chunks_table_data:
|
||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||
|
||||
return ids
|
||||
|
||||
@ -185,11 +185,10 @@ class RelytVector(BaseVector):
|
||||
)
|
||||
|
||||
try:
|
||||
with self.client.connect() as conn:
|
||||
with conn.begin():
|
||||
delete_condition = chunks_table.c.id.in_(ids)
|
||||
conn.execute(chunks_table.delete().where(delete_condition))
|
||||
return True
|
||||
with self.client.connect() as conn, conn.begin():
|
||||
delete_condition = chunks_table.c.id.in_(ids)
|
||||
conn.execute(chunks_table.delete().where(delete_condition))
|
||||
return True
|
||||
except Exception as e:
|
||||
print("Delete operation failed:", str(e))
|
||||
return False
|
||||
|
@ -63,10 +63,7 @@ class TencentVector(BaseVector):
|
||||
|
||||
def _has_collection(self) -> bool:
|
||||
collections = self._db.list_collections()
|
||||
for collection in collections:
|
||||
if collection.collection_name == self._collection_name:
|
||||
return True
|
||||
return False
|
||||
return any(collection.collection_name == self._collection_name for collection in collections)
|
||||
|
||||
def _create_collection(self, dimension: int) -> None:
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
|
@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
|
||||
program_name: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||
@ -123,20 +124,19 @@ class TiDBVector(BaseVector):
|
||||
texts = [d.page_content for d in documents]
|
||||
|
||||
chunks_table_data = []
|
||||
with self._engine.connect() as conn:
|
||||
with conn.begin():
|
||||
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
||||
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||
with self._engine.connect() as conn, conn.begin():
|
||||
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
||||
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||
|
||||
# Execute the batch insert when the batch size is reached
|
||||
if len(chunks_table_data) == 500:
|
||||
conn.execute(insert(table).values(chunks_table_data))
|
||||
# Clear the chunks_table_data list for the next batch
|
||||
chunks_table_data.clear()
|
||||
|
||||
# Insert any remaining records that didn't make up a full batch
|
||||
if chunks_table_data:
|
||||
# Execute the batch insert when the batch size is reached
|
||||
if len(chunks_table_data) == 500:
|
||||
conn.execute(insert(table).values(chunks_table_data))
|
||||
# Clear the chunks_table_data list for the next batch
|
||||
chunks_table_data.clear()
|
||||
|
||||
# Insert any remaining records that didn't make up a full batch
|
||||
if chunks_table_data:
|
||||
conn.execute(insert(table).values(chunks_table_data))
|
||||
return ids
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
@ -159,11 +159,10 @@ class TiDBVector(BaseVector):
|
||||
raise ValueError("No ids provided to delete.")
|
||||
table = self._table(self._dimension)
|
||||
try:
|
||||
with self._engine.connect() as conn:
|
||||
with conn.begin():
|
||||
delete_condition = table.c.id.in_(ids)
|
||||
conn.execute(table.delete().where(delete_condition))
|
||||
return True
|
||||
with self._engine.connect() as conn, conn.begin():
|
||||
delete_condition = table.c.id.in_(ids)
|
||||
conn.execute(table.delete().where(delete_condition))
|
||||
return True
|
||||
except Exception as e:
|
||||
print("Delete operation failed:", str(e))
|
||||
return False
|
||||
|
@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
|
||||
batch_size: int = 100
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
|
@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
|
||||
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
|
||||
|
||||
self.web_path = self.file_path
|
||||
self.temp_file = tempfile.NamedTemporaryFile()
|
||||
# TODO: use a better way to handle the file
|
||||
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
|
||||
self.temp_file.write(r.content)
|
||||
self.file_path = self.temp_file.name
|
||||
elif not os.path.isfile(self.file_path):
|
||||
|
@ -120,8 +120,8 @@ class WeightRerankRunner:
|
||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1)
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2)
|
||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||
|
||||
if not denominator:
|
||||
|
@ -581,8 +581,8 @@ class DatasetRetrieval:
|
||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
||||
sum1 = sum(vec1[x] ** 2 for x in vec1)
|
||||
sum2 = sum(vec2[x] ** 2 for x in vec2)
|
||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||
|
||||
if not denominator:
|
||||
|
@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||
|
@ -71,7 +71,7 @@ class BingSearchTool(BuiltinTool):
|
||||
text = ""
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and "expression" in computation and "value" in computation:
|
||||
text += "\nComputation:\n"
|
||||
|
@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
|
||||
for image in response.data:
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
@ -83,5 +83,5 @@ class DIDApp:
|
||||
if status["status"] == "done":
|
||||
return status
|
||||
elif status["status"] == "error" or status["status"] == "rejected":
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
|
||||
time.sleep(poll_interval)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import urllib.parse
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
@ -13,13 +14,14 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
project = tool_parameters.get("project", "")
|
||||
repository = tool_parameters.get("repository", "")
|
||||
employee = tool_parameters.get("employee", "")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
change_type = tool_parameters.get("change_type", "all")
|
||||
|
||||
if not project:
|
||||
return self.create_text_message("Project is required")
|
||||
if not project and not repository:
|
||||
return self.create_text_message("Either project or repository is required")
|
||||
|
||||
if not start_time:
|
||||
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
|
||||
@ -35,91 +37,105 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
site_url = "https://gitlab.com"
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
|
||||
if repository:
|
||||
result = self.fetch_commits(
|
||||
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
|
||||
)
|
||||
else:
|
||||
result = self.fetch_commits(
|
||||
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
|
||||
)
|
||||
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def fetch(
|
||||
def fetch_commits(
|
||||
self,
|
||||
user_id: str,
|
||||
site_url: str,
|
||||
access_token: str,
|
||||
project: str,
|
||||
employee: str = None,
|
||||
start_time: str = "",
|
||||
end_time: str = "",
|
||||
change_type: str = "",
|
||||
identifier: str,
|
||||
employee: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
change_type: str,
|
||||
is_repository: bool,
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Get all of projects
|
||||
url = f"{domain}/api/v4/projects"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
if is_repository:
|
||||
# URL encode the repository path
|
||||
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
|
||||
else:
|
||||
# Get all projects
|
||||
url = f"{domain}/api/v4/projects"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
|
||||
filtered_projects = [p for p in projects if project == "*" or p["name"] == project]
|
||||
filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
|
||||
|
||||
for project in filtered_projects:
|
||||
project_id = project["id"]
|
||||
project_name = project["name"]
|
||||
print(f"Project: {project_name}")
|
||||
for project in filtered_projects:
|
||||
project_id = project["id"]
|
||||
project_name = project["name"]
|
||||
print(f"Project: {project_name}")
|
||||
|
||||
# Get all of project commits
|
||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||
params = {"since": start_time, "until": end_time}
|
||||
if employee:
|
||||
params["author"] = employee
|
||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||
|
||||
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||
commits_response.raise_for_status()
|
||||
commits = commits_response.json()
|
||||
params = {"since": start_time, "until": end_time}
|
||||
if employee:
|
||||
params["author"] = employee
|
||||
|
||||
for commit in commits:
|
||||
commit_sha = commit["id"]
|
||||
author_name = commit["author_name"]
|
||||
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||
commits_response.raise_for_status()
|
||||
commits = commits_response.json()
|
||||
|
||||
for commit in commits:
|
||||
commit_sha = commit["id"]
|
||||
author_name = commit["author_name"]
|
||||
|
||||
if is_repository:
|
||||
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
|
||||
else:
|
||||
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
||||
diff_response = requests.get(diff_url, headers=headers)
|
||||
diff_response.raise_for_status()
|
||||
diffs = diff_response.json()
|
||||
|
||||
for diff in diffs:
|
||||
# Calculate code lines of changed
|
||||
added_lines = diff["diff"].count("\n+")
|
||||
removed_lines = diff["diff"].count("\n-")
|
||||
total_changes = added_lines + removed_lines
|
||||
diff_response = requests.get(diff_url, headers=headers)
|
||||
diff_response.raise_for_status()
|
||||
diffs = diff_response.json()
|
||||
|
||||
if change_type == "new":
|
||||
if added_lines > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if line.startswith("+") and not line.startswith("+++")
|
||||
]
|
||||
)
|
||||
results.append(
|
||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}
|
||||
)
|
||||
else:
|
||||
if total_changes > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if (line.startswith("+") or line.startswith("-"))
|
||||
and not line.startswith("+++")
|
||||
and not line.startswith("---")
|
||||
]
|
||||
)
|
||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||
results.append(
|
||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
|
||||
)
|
||||
for diff in diffs:
|
||||
# Calculate code lines of changes
|
||||
added_lines = diff["diff"].count("\n+")
|
||||
removed_lines = diff["diff"].count("\n-")
|
||||
total_changes = added_lines + removed_lines
|
||||
|
||||
if change_type == "new":
|
||||
if added_lines > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if line.startswith("+") and not line.startswith("+++")
|
||||
]
|
||||
)
|
||||
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
|
||||
else:
|
||||
if total_changes > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if (line.startswith("+") or line.startswith("-"))
|
||||
and not line.startswith("+++")
|
||||
and not line.startswith("---")
|
||||
]
|
||||
)
|
||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||
results.append(
|
||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
|
@ -21,9 +21,20 @@ parameters:
|
||||
zh_Hans: 员工用户名
|
||||
llm_description: User name for GitLab
|
||||
form: llm
|
||||
- name: repository
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径
|
||||
human_description:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||
form: llm
|
||||
- name: project
|
||||
type: string
|
||||
required: true
|
||||
required: false
|
||||
label:
|
||||
en_US: project
|
||||
zh_Hans: 项目名
|
||||
|
@ -1,3 +1,4 @@
|
||||
import urllib.parse
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
@ -11,14 +12,14 @@ class GitlabFilesTool(BuiltinTool):
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
project = tool_parameters.get("project", "")
|
||||
repository = tool_parameters.get("repository", "")
|
||||
branch = tool_parameters.get("branch", "")
|
||||
path = tool_parameters.get("path", "")
|
||||
|
||||
if not project:
|
||||
return self.create_text_message("Project is required")
|
||||
if not project and not repository:
|
||||
return self.create_text_message("Either project or repository is required")
|
||||
if not branch:
|
||||
return self.create_text_message("Branch is required")
|
||||
|
||||
if not path:
|
||||
return self.create_text_message("Path is required")
|
||||
|
||||
@ -30,21 +31,59 @@ class GitlabFilesTool(BuiltinTool):
|
||||
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
||||
site_url = "https://gitlab.com"
|
||||
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, project)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{project}' not found.")
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
|
||||
# Get file content
|
||||
if repository:
|
||||
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
|
||||
else:
|
||||
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
|
||||
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
|
||||
parts = path.split("/", 1)
|
||||
if len(parts) < 2:
|
||||
return None, None
|
||||
return parts[0], parts[1]
|
||||
def fetch_files(
|
||||
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
if is_repository:
|
||||
# URL encode the repository path
|
||||
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
|
||||
else:
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, identifier)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{identifier}' not found.")
|
||||
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
|
||||
response = requests.get(tree_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
|
||||
for item in items:
|
||||
item_path = item["path"]
|
||||
if item["type"] == "tree": # It's a directory
|
||||
results.extend(
|
||||
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
|
||||
)
|
||||
else: # It's a file
|
||||
if is_repository:
|
||||
file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}"
|
||||
else:
|
||||
file_url = (
|
||||
f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||
)
|
||||
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({"path": item_path, "branch": branch, "content": file_content})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
@ -59,32 +98,3 @@ class GitlabFilesTool(BuiltinTool):
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching project ID from GitLab: {e}")
|
||||
return None
|
||||
|
||||
def fetch(
|
||||
self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
# List files and directories in the given path
|
||||
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
|
||||
for item in items:
|
||||
item_path = item["path"]
|
||||
if item["type"] == "tree": # It's a directory
|
||||
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
|
||||
else: # It's a file
|
||||
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({"path": item_path, "branch": branch, "content": file_content})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
|
@ -10,9 +10,20 @@ description:
|
||||
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
|
||||
llm: A tool for query GitLab files, Input should be a exists file or directory path.
|
||||
parameters:
|
||||
- name: repository
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径
|
||||
human_description:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||
form: llm
|
||||
- name: project
|
||||
type: string
|
||||
required: true
|
||||
required: false
|
||||
label:
|
||||
en_US: project
|
||||
zh_Hans: 项目
|
||||
|
@ -142,7 +142,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
||||
for control in controls:
|
||||
control_type_id = self.get_real_type_id(control)
|
||||
if (control_type_id in self._get_ignore_types()) or (
|
||||
allow_fields and not control["controlId"] in allow_fields
|
||||
allow_fields and control["controlId"] not in allow_fields
|
||||
):
|
||||
continue
|
||||
else:
|
||||
@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
||||
elif value.startswith('[{"organizeId"'):
|
||||
value = json.loads(value)
|
||||
value = "、".join([item["organizeName"] for item in value])
|
||||
elif value.startswith('[{"file_id"'):
|
||||
value = ""
|
||||
elif value == "[]":
|
||||
elif value.startswith('[{"file_id"') or value == "[]":
|
||||
value = ""
|
||||
elif hasattr(value, "accountId"):
|
||||
value = value["fullname"]
|
||||
|
@ -67,7 +67,7 @@ class ListWorksheetsTool(BuiltinTool):
|
||||
items = []
|
||||
tables = ""
|
||||
for item in section.get("items", []):
|
||||
if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"):
|
||||
if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"):
|
||||
if type == "json":
|
||||
filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")}
|
||||
items.append(filtered_item)
|
||||
|
@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(client_result.image_file),
|
||||
meta={"mime_type": f"image/{client_result.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
|
||||
models_data=[],
|
||||
headers=headers,
|
||||
params=params,
|
||||
recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
|
||||
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
|
||||
)
|
||||
|
||||
result_str = ""
|
||||
|
@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image_encoded),
|
||||
meta={"mime_type": f"image/{image.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -39,14 +39,14 @@ class QRCodeGeneratorTool(BuiltinTool):
|
||||
|
||||
# get error_correction
|
||||
error_correction = tool_parameters.get("error_correction", "")
|
||||
if error_correction not in self.error_correction_levels.keys():
|
||||
if error_correction not in self.error_correction_levels:
|
||||
return self.create_text_message("Invalid parameter error_correction")
|
||||
|
||||
try:
|
||||
image = self._generate_qrcode(content, border, error_correction)
|
||||
image_bytes = self._image_to_byte_array(image)
|
||||
return self.create_blob_message(
|
||||
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(f"Failed to generate QR code for content: {content}")
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user