mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 01:45:57 +08:00
Merge branch 'main' into feat/attachments
This commit is contained in:
commit
323a835de9
@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
|
@ -20,7 +20,7 @@ from fields.conversation_fields import (
|
|||||||
conversation_pagination_fields,
|
conversation_pagination_fields,
|
||||||
conversation_with_summary_pagination_fields,
|
conversation_with_summary_pagination_fields,
|
||||||
)
|
)
|
||||||
from libs.helper import datetime_string
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||||
|
|
||||||
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||||
)
|
)
|
||||||
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
|
|||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||||
)
|
)
|
||||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import datetime_string
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import datetime_string
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from models.workflow import WorkflowRunTriggeredFrom
|
from models.workflow import WorkflowRunTriggeredFrom
|
||||||
@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
account = current_user
|
account = current_user
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
|
@ -8,7 +8,7 @@ from constants.languages import supported_language
|
|||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.error import AlreadyActivateError
|
from controllers.console.error import AlreadyActivateError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, str_len, timezone
|
from libs.helper import StrLen, email, timezone
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models.account import AccountStatus
|
from models.account import AccountStatus
|
||||||
from services.account_service import RegisterService
|
from services.account_service import RegisterService
|
||||||
@ -37,7 +37,7 @@ class ActivateApi(Resource):
|
|||||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||||
|
@ -4,7 +4,7 @@ from flask import session
|
|||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import str_len
|
from libs.helper import StrLen
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
|
|||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
||||||
input_password = parser.parse_args()["password"]
|
input_password = parser.parse_args()["password"]
|
||||||
|
|
||||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||||
|
@ -4,7 +4,7 @@ from flask import request
|
|||||||
from flask_restful import Resource, reqparse
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import email, get_remote_ip, str_len
|
from libs.helper import StrLen, email, get_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
@ -40,7 +40,7 @@ class SetupApi(Resource):
|
|||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
|||||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
runner.run()
|
runner.run()
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedError:
|
||||||
pass
|
pass
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
queue_manager.publish_error(
|
queue_manager.publish_error(
|
||||||
@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
return generate_task_pipeline.process()
|
return generate_task_pipeline.process()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
raise GenerateTaskStoppedException()
|
raise GenerateTaskStoppedError()
|
||||||
else:
|
else:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -21,7 +21,7 @@ class AudioTrunk:
|
|||||||
self.status = status
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||||
if not text_content or text_content.isspace():
|
if not text_content or text_content.isspace():
|
||||||
return
|
return
|
||||||
return model_instance.invoke_tts(
|
return model_instance.invoke_tts(
|
||||||
@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
|
|||||||
if message is None:
|
if message is None:
|
||||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||||
futures_result = self.executor.submit(
|
futures_result = self.executor.submit(
|
||||||
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||||
)
|
)
|
||||||
future_queue.put(futures_result)
|
future_queue.put(futures_result)
|
||||||
break
|
break
|
||||||
@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
|
|||||||
self.MAX_SENTENCE += 1
|
self.MAX_SENTENCE += 1
|
||||||
text_content = "".join(sentence_arr)
|
text_content = "".join(sentence_arr)
|
||||||
futures_result = self.executor.submit(
|
futures_result = self.executor.submit(
|
||||||
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
|
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||||
)
|
)
|
||||||
future_queue.put(futures_result)
|
future_queue.put(futures_result)
|
||||||
if text_tmp:
|
if text_tmp:
|
||||||
@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
|
|||||||
break
|
break
|
||||||
future_queue.put(None)
|
future_queue.put(None)
|
||||||
|
|
||||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
def check_and_get_audio(self) -> AudioTrunk | None:
|
||||||
try:
|
try:
|
||||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||||
if self.executor:
|
if self.executor:
|
||||||
|
@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
)
|
)
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationError
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
from core.workflow.entities.node_entities import UserFrom
|
from core.workflow.entities.node_entities import UserFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
query=query,
|
query=query,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
except ModerationException as e:
|
except ModerationError as e:
|
||||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
stream_response=stream_response,
|
stream_response=stream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _listenAudioMsg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher, task_id: str):
|
||||||
if not publisher:
|
if not publisher:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and audio_msg.status != "finish":
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
return None
|
return None
|
||||||
@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
|
|
||||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||||
while True:
|
while True:
|
||||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||||
if audio_response:
|
if audio_response:
|
||||||
yield audio_response
|
yield audio_response
|
||||||
else:
|
else:
|
||||||
@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
try:
|
try:
|
||||||
if not tts_publisher:
|
if not tts_publisher:
|
||||||
break
|
break
|
||||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
audio_trunk = tts_publisher.check_and_get_audio()
|
||||||
if audio_trunk is None:
|
if audio_trunk is None:
|
||||||
# release cpu
|
# release cpu
|
||||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||||
|
@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
|||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||||
@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedError:
|
||||||
pass
|
pass
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
queue_manager.publish_error(
|
queue_manager.publish_error(
|
||||||
|
@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
|
|||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationError
|
||||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Conversation, Message, MessageAgentThought
|
from models.model import App, Conversation, Message, MessageAgentThought
|
||||||
@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
|
|||||||
query=query,
|
query=query,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
except ModerationException as e:
|
except ModerationError as e:
|
||||||
self.direct_output(
|
self.direct_output(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
app_generate_entity=application_generate_entity,
|
app_generate_entity=application_generate_entity,
|
||||||
|
@ -171,5 +171,5 @@ class AppQueueManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GenerateTaskStoppedException(Exception):
|
class GenerateTaskStoppedError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||||
@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
message=message,
|
message=message,
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedError:
|
||||||
pass
|
pass
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
queue_manager.publish_error(
|
queue_manager.publish_error(
|
||||||
|
@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
|||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationError
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Conversation, Message
|
from models.model import App, Conversation, Message
|
||||||
@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
|
|||||||
query=query,
|
query=query,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
except ModerationException as e:
|
except ModerationError as e:
|
||||||
self.direct_output(
|
self.direct_output(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
app_generate_entity=application_generate_entity,
|
app_generate_entity=application_generate_entity,
|
||||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||||
@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
message=message,
|
message=message,
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedError:
|
||||||
pass
|
pass
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
queue_manager.publish_error(
|
queue_manager.publish_error(
|
||||||
|
@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
)
|
)
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationError
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Message
|
from models.model import App, Message
|
||||||
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
|
|||||||
query=query,
|
query=query,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
except ModerationException as e:
|
except ModerationError as e:
|
||||||
self.direct_output(
|
self.direct_output(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
app_generate_entity=application_generate_entity,
|
app_generate_entity=application_generate_entity,
|
||||||
|
@ -8,7 +8,7 @@ from sqlalchemy import and_
|
|||||||
|
|
||||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
AgentChatAppGenerateEntity,
|
AgentChatAppGenerateEntity,
|
||||||
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
return generate_task_pipeline.process()
|
return generate_task_pipeline.process()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
raise GenerateTaskStoppedException()
|
raise GenerateTaskStoppedError()
|
||||||
else:
|
else:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
|||||||
self.stop_listen()
|
self.stop_listen()
|
||||||
|
|
||||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||||
raise GenerateTaskStoppedException()
|
raise GenerateTaskStoppedError()
|
||||||
|
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
|||||||
import contexts
|
import contexts
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||||
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
runner.run()
|
runner.run()
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedError:
|
||||||
pass
|
pass
|
||||||
except InvokeAuthorizationError:
|
except InvokeAuthorizationError:
|
||||||
queue_manager.publish_error(
|
queue_manager.publish_error(
|
||||||
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
return generate_task_pipeline.process()
|
return generate_task_pipeline.process()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
raise GenerateTaskStoppedException()
|
raise GenerateTaskStoppedError()
|
||||||
else:
|
else:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
raise e
|
raise e
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
|
|||||||
self.stop_listen()
|
self.stop_listen()
|
||||||
|
|
||||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||||
raise GenerateTaskStoppedException()
|
raise GenerateTaskStoppedError()
|
||||||
|
@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
|
|
||||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||||
|
|
||||||
def _listenAudioMsg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher, task_id: str):
|
||||||
if not publisher:
|
if not publisher:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and audio_msg.status != "finish":
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
return None
|
return None
|
||||||
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
|
|
||||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||||
while True:
|
while True:
|
||||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||||
if audio_response:
|
if audio_response:
|
||||||
yield audio_response
|
yield audio_response
|
||||||
else:
|
else:
|
||||||
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
try:
|
try:
|
||||||
if not tts_publisher:
|
if not tts_publisher:
|
||||||
break
|
break
|
||||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
audio_trunk = tts_publisher.check_and_get_audio()
|
||||||
if audio_trunk is None:
|
if audio_trunk is None:
|
||||||
# release cpu
|
# release cpu
|
||||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||||
|
@ -15,6 +15,7 @@ class Segment(BaseModel):
|
|||||||
value: Any
|
value: Any
|
||||||
|
|
||||||
@field_validator("value_type")
|
@field_validator("value_type")
|
||||||
|
@classmethod
|
||||||
def validate_value_type(cls, value):
|
def validate_value_type(cls, value):
|
||||||
"""
|
"""
|
||||||
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
||||||
|
@ -65,7 +65,7 @@ class BasedGenerateTaskPipeline:
|
|||||||
|
|
||||||
if isinstance(e, InvokeAuthorizationError):
|
if isinstance(e, InvokeAuthorizationError):
|
||||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
elif isinstance(e, InvokeError | ValueError):
|
||||||
err = e
|
err = e
|
||||||
else:
|
else:
|
||||||
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
|
||||||
|
@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
stream_response=stream_response,
|
stream_response=stream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _listenAudioMsg(self, publisher, task_id: str):
|
def _listen_audio_msg(self, publisher, task_id: str):
|
||||||
if publisher is None:
|
if publisher is None:
|
||||||
return None
|
return None
|
||||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||||
if audio_msg and audio_msg.status != "finish":
|
if audio_msg and audio_msg.status != "finish":
|
||||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||||
while True:
|
while True:
|
||||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||||
if audio_response:
|
if audio_response:
|
||||||
yield audio_response
|
yield audio_response
|
||||||
else:
|
else:
|
||||||
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||||
if publisher is None:
|
if publisher is None:
|
||||||
break
|
break
|
||||||
audio = publisher.checkAndGetAudio()
|
audio = publisher.check_and_get_audio()
|
||||||
if audio is None:
|
if audio is None:
|
||||||
# release cpu
|
# release cpu
|
||||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||||
|
@ -67,7 +67,7 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
data_source_type=item.get("data_source_type"),
|
data_source_type=item.get("data_source_type"),
|
||||||
segment_id=item.get("segment_id"),
|
segment_id=item.get("segment_id"),
|
||||||
score=item.get("score") if "score" in item else None,
|
score=item.get("score") if "score" in item else None,
|
||||||
hit_count=item.get("hit_count") if "hit_count" else None,
|
hit_count=item.get("hit_count") if "hit_count" in item else None,
|
||||||
word_count=item.get("word_count") if "word_count" in item else None,
|
word_count=item.get("word_count") if "word_count" in item else None,
|
||||||
segment_position=item.get("segment_position") if "segment_position" in item else None,
|
segment_position=item.get("segment_position") if "segment_position" in item else None,
|
||||||
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
|
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
|
||||||
|
@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CodeExecutionException(Exception):
|
class CodeExecutionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -86,15 +86,15 @@ class CodeExecutor:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if response.status_code == 503:
|
if response.status_code == 503:
|
||||||
raise CodeExecutionException("Code execution service is unavailable")
|
raise CodeExecutionError("Code execution service is unavailable")
|
||||||
elif response.status_code != 200:
|
elif response.status_code != 200:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
|
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
|
||||||
)
|
)
|
||||||
except CodeExecutionException as e:
|
except CodeExecutionError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise CodeExecutionException(
|
raise CodeExecutionError(
|
||||||
"Failed to execute code, which is likely a network issue,"
|
"Failed to execute code, which is likely a network issue,"
|
||||||
" please check if the sandbox service is running."
|
" please check if the sandbox service is running."
|
||||||
f" ( Error: {str(e)} )"
|
f" ( Error: {str(e)} )"
|
||||||
@ -103,15 +103,15 @@ class CodeExecutor:
|
|||||||
try:
|
try:
|
||||||
response = response.json()
|
response = response.json()
|
||||||
except:
|
except:
|
||||||
raise CodeExecutionException("Failed to parse response")
|
raise CodeExecutionError("Failed to parse response")
|
||||||
|
|
||||||
if (code := response.get("code")) != 0:
|
if (code := response.get("code")) != 0:
|
||||||
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||||
|
|
||||||
response = CodeExecutionResponse(**response)
|
response = CodeExecutionResponse(**response)
|
||||||
|
|
||||||
if response.data.error:
|
if response.data.error:
|
||||||
raise CodeExecutionException(response.data.error)
|
raise CodeExecutionError(response.data.error)
|
||||||
|
|
||||||
return response.data.stdout or ""
|
return response.data.stdout or ""
|
||||||
|
|
||||||
@ -126,13 +126,13 @@ class CodeExecutor:
|
|||||||
"""
|
"""
|
||||||
template_transformer = cls.code_template_transformers.get(language)
|
template_transformer = cls.code_template_transformers.get(language)
|
||||||
if not template_transformer:
|
if not template_transformer:
|
||||||
raise CodeExecutionException(f"Unsupported language {language}")
|
raise CodeExecutionError(f"Unsupported language {language}")
|
||||||
|
|
||||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = cls.execute_code(language, preload, runner)
|
response = cls.execute_code(language, preload, runner)
|
||||||
except CodeExecutionException as e:
|
except CodeExecutionError as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return template_transformer.transform_response(response)
|
return template_transformer.transform_response(response)
|
||||||
|
@ -78,8 +78,8 @@ class IndexingRunner:
|
|||||||
dataset_document=dataset_document,
|
dataset_document=dataset_document,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
)
|
)
|
||||||
except DocumentIsPausedException:
|
except DocumentIsPausedError:
|
||||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||||
except ProviderTokenNotInitError as e:
|
except ProviderTokenNotInitError as e:
|
||||||
dataset_document.indexing_status = "error"
|
dataset_document.indexing_status = "error"
|
||||||
dataset_document.error = str(e.description)
|
dataset_document.error = str(e.description)
|
||||||
@ -134,8 +134,8 @@ class IndexingRunner:
|
|||||||
self._load(
|
self._load(
|
||||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||||
)
|
)
|
||||||
except DocumentIsPausedException:
|
except DocumentIsPausedError:
|
||||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||||
except ProviderTokenNotInitError as e:
|
except ProviderTokenNotInitError as e:
|
||||||
dataset_document.indexing_status = "error"
|
dataset_document.indexing_status = "error"
|
||||||
dataset_document.error = str(e.description)
|
dataset_document.error = str(e.description)
|
||||||
@ -192,8 +192,8 @@ class IndexingRunner:
|
|||||||
self._load(
|
self._load(
|
||||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||||
)
|
)
|
||||||
except DocumentIsPausedException:
|
except DocumentIsPausedError:
|
||||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||||
except ProviderTokenNotInitError as e:
|
except ProviderTokenNotInitError as e:
|
||||||
dataset_document.indexing_status = "error"
|
dataset_document.indexing_status = "error"
|
||||||
dataset_document.error = str(e.description)
|
dataset_document.error = str(e.description)
|
||||||
@ -756,7 +756,7 @@ class IndexingRunner:
|
|||||||
indexing_cache_key = "document_{}_is_paused".format(document_id)
|
indexing_cache_key = "document_{}_is_paused".format(document_id)
|
||||||
result = redis_client.get(indexing_cache_key)
|
result = redis_client.get(indexing_cache_key)
|
||||||
if result:
|
if result:
|
||||||
raise DocumentIsPausedException()
|
raise DocumentIsPausedError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_document_index_status(
|
def _update_document_index_status(
|
||||||
@ -767,10 +767,10 @@ class IndexingRunner:
|
|||||||
"""
|
"""
|
||||||
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
raise DocumentIsPausedException()
|
raise DocumentIsPausedError()
|
||||||
document = DatasetDocument.query.filter_by(id=document_id).first()
|
document = DatasetDocument.query.filter_by(id=document_id).first()
|
||||||
if not document:
|
if not document:
|
||||||
raise DocumentIsDeletedPausedException()
|
raise DocumentIsDeletedPausedError()
|
||||||
|
|
||||||
update_params = {DatasetDocument.indexing_status: after_indexing_status}
|
update_params = {DatasetDocument.indexing_status: after_indexing_status}
|
||||||
|
|
||||||
@ -875,9 +875,9 @@ class IndexingRunner:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DocumentIsPausedException(Exception):
|
class DocumentIsPausedError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DocumentIsDeletedPausedException(Exception):
|
class DocumentIsDeletedPausedError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
class OutputParserException(Exception):
|
class OutputParserError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserException
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.prompts import (
|
from core.llm_generator.prompts import (
|
||||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||||
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
|
|||||||
raise ValueError("Expected 'opening_statement' to be a str.")
|
raise ValueError("Expected 'opening_statement' to be a str.")
|
||||||
return parsed
|
return parsed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
||||||
|
@ -420,7 +420,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
index += 0
|
index += 1
|
||||||
|
|
||||||
# calculate num tokens
|
# calculate num tokens
|
||||||
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
|
||||||
|
@ -7,7 +7,7 @@ from requests import post
|
|||||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InternalServerError,
|
InternalServerError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
@ -45,7 +45,7 @@ class BaichuanModel:
|
|||||||
parameters: dict[str, Any],
|
parameters: dict[str, Any],
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if model in self._model_mapping.keys():
|
if model in self._model_mapping:
|
||||||
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
||||||
# we need to rename it to res_format to get its value
|
# we need to rename it to res_format to get its value
|
||||||
if parameters.get("res_format") == "json_object":
|
if parameters.get("res_format") == "json_object":
|
||||||
@ -94,7 +94,7 @@ class BaichuanModel:
|
|||||||
timeout: int,
|
timeout: int,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
) -> Union[Iterator, dict]:
|
) -> Union[Iterator, dict]:
|
||||||
if model in self._model_mapping.keys():
|
if model in self._model_mapping:
|
||||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||||
else:
|
else:
|
||||||
raise BadRequestError(f"Unknown model: {model}")
|
raise BadRequestError(f"Unknown model: {model}")
|
||||||
@ -124,7 +124,7 @@ class BaichuanModel:
|
|||||||
if err == "invalid_api_key":
|
if err == "invalid_api_key":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
elif err == "insufficient_quota":
|
elif err == "insufficient_quota":
|
||||||
raise InsufficientAccountBalance(msg)
|
raise InsufficientAccountBalanceError(msg)
|
||||||
elif err == "invalid_authentication":
|
elif err == "invalid_authentication":
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
elif err == "invalid_request_error":
|
elif err == "invalid_request_error":
|
||||||
|
@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InsufficientAccountBalance(Exception):
|
class InsufficientAccountBalanceError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
|
|||||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
||||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InternalServerError,
|
InternalServerError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
|||||||
InvokeRateLimitError: [RateLimitReachedError],
|
InvokeRateLimitError: [RateLimitReachedError],
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||||
|
@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
|||||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InternalServerError,
|
InternalServerError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
if err == "invalid_api_key":
|
if err == "invalid_api_key":
|
||||||
raise InvalidAPIKeyError(msg)
|
raise InvalidAPIKeyError(msg)
|
||||||
elif err == "insufficient_quota":
|
elif err == "insufficient_quota":
|
||||||
raise InsufficientAccountBalance(msg)
|
raise InsufficientAccountBalanceError(msg)
|
||||||
elif err == "invalid_authentication":
|
elif err == "invalid_authentication":
|
||||||
raise InvalidAuthenticationError(msg)
|
raise InvalidAuthenticationError(msg)
|
||||||
elif err and "rate" in err:
|
elif err and "rate" in err:
|
||||||
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
InvokeRateLimitError: [RateLimitReachedError],
|
InvokeRateLimitError: [RateLimitReachedError],
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.00025'
|
input: '0.00025'
|
||||||
output: '0.00125'
|
output: '0.00125'
|
||||||
|
@ -52,6 +52,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.015'
|
input: '0.015'
|
||||||
output: '0.075'
|
output: '0.075'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -51,6 +51,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.003'
|
input: '0.003'
|
||||||
output: '0.015'
|
output: '0.015'
|
||||||
|
@ -45,6 +45,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.008'
|
input: '0.008'
|
||||||
output: '0.024'
|
output: '0.024'
|
||||||
|
@ -45,6 +45,8 @@ parameter_rules:
|
|||||||
help:
|
help:
|
||||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
pricing:
|
pricing:
|
||||||
input: '0.008'
|
input: '0.008'
|
||||||
output: '0.024'
|
output: '0.024'
|
||||||
|
@ -20,6 +20,7 @@ from botocore.exceptions import (
|
|||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
|
||||||
# local import
|
# local import
|
||||||
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -44,6 +45,14 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||||
|
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||||
|
if you are not sure about the structure.
|
||||||
|
|
||||||
|
<instructions>
|
||||||
|
{{instructions}}
|
||||||
|
</instructions>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||||
@ -70,6 +79,40 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||||||
logger.info(f"current model id: {model_id} did not support by Converse API")
|
logger.info(f"current model id: {model_id} did not support by Converse API")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _code_block_mode_wrapper(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
callbacks: list[Callback] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
|
"""
|
||||||
|
Code block mode wrapper for invoking large language model
|
||||||
|
"""
|
||||||
|
if model_parameters.get("response_format"):
|
||||||
|
stop = stop or []
|
||||||
|
if "```\n" not in stop:
|
||||||
|
stop.append("```\n")
|
||||||
|
if "\n```" not in stop:
|
||||||
|
stop.append("\n```")
|
||||||
|
response_format = model_parameters.pop("response_format")
|
||||||
|
format_prompt = SystemPromptMessage(
|
||||||
|
content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
|
||||||
|
"{{block}}", response_format
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||||
|
prompt_messages[0] = format_prompt
|
||||||
|
else:
|
||||||
|
prompt_messages.insert(0, format_prompt)
|
||||||
|
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
|
||||||
|
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -337,9 +337,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
@ -442,9 +442,7 @@ class OCILargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
@ -77,7 +77,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
|
|||||||
inputs.append(text)
|
inputs.append(text)
|
||||||
|
|
||||||
# Prepare the payload for the request
|
# Prepare the payload for the request
|
||||||
payload = {"input": inputs, "model": model, "options": {"use_mmap": "true"}}
|
payload = {"input": inputs, "model": model, "options": {"use_mmap": True}}
|
||||||
|
|
||||||
# Make the request to the Ollama API
|
# Make the request to the Ollama API
|
||||||
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
|
||||||
|
@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _CommonOAI_API_Compat:
|
class _CommonOaiApiCompat:
|
||||||
@property
|
@property
|
||||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
"""
|
"""
|
||||||
|
@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||||
from core.model_runtime.utils import helper
|
from core.model_runtime.utils import helper
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||||
"""
|
"""
|
||||||
Model class for OpenAI large language model.
|
Model class for OpenAI large language model.
|
||||||
"""
|
"""
|
||||||
|
@ -6,10 +6,10 @@ import requests
|
|||||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||||
|
|
||||||
|
|
||||||
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
|
class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
|
||||||
"""
|
"""
|
||||||
Model class for OpenAI Compatible Speech to text model.
|
Model class for OpenAI Compatible Speech to text model.
|
||||||
"""
|
"""
|
||||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||||
|
|
||||||
|
|
||||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||||
"""
|
"""
|
||||||
Model class for an OpenAI API-compatible text embedding model.
|
Model class for an OpenAI API-compatible text embedding model.
|
||||||
"""
|
"""
|
||||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||||
|
|
||||||
|
|
||||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||||
"""
|
"""
|
||||||
Model class for an OpenAI API-compatible text embedding model.
|
Model class for an OpenAI API-compatible text embedding model.
|
||||||
"""
|
"""
|
||||||
|
@ -350,9 +350,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
|||||||
break
|
break
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = content
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = content
|
message_text = content
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
@ -115,7 +115,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
token = credentials.token
|
token = credentials.token
|
||||||
|
|
||||||
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
|
# Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region
|
||||||
if "opus" or "claude-3-5-sonnet" in model:
|
if "opus" in model or "claude-3-5-sonnet" in model:
|
||||||
location = "us-east5"
|
location = "us-east5"
|
||||||
else:
|
else:
|
||||||
location = "us-central1"
|
location = "us-central1"
|
||||||
@ -633,9 +633,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = f"{human_prompt} {content}"
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
|
||||||
|
|
||||||
|
|
||||||
class MaaSClient(MaasService):
|
class MaaSClient(MaasService):
|
||||||
@ -106,7 +106,7 @@ class MaaSClient(MaasService):
|
|||||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||||
try:
|
try:
|
||||||
resp = fn()
|
resp = fn()
|
||||||
except MaasException as e:
|
except MaasError as e:
|
||||||
raise wrap_error(e)
|
raise wrap_error(e)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
@ -1,144 +1,144 @@
|
|||||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
|
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
|
||||||
|
|
||||||
|
|
||||||
class ClientSDKRequestError(MaasException):
|
class ClientSDKRequestError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SignatureDoesNotMatch(MaasException):
|
class SignatureDoesNotMatchError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RequestTimeout(MaasException):
|
class RequestTimeoutError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServiceConnectionTimeout(MaasException):
|
class ServiceConnectionTimeoutError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingAuthenticationHeader(MaasException):
|
class MissingAuthenticationHeaderError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationHeaderIsInvalid(MaasException):
|
class AuthenticationHeaderIsInvalidError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InternalServiceError(MaasException):
|
class InternalServiceError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingParameter(MaasException):
|
class MissingParameterError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidParameter(MaasException):
|
class InvalidParameterError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationExpire(MaasException):
|
class AuthenticationExpireError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointIsInvalid(MaasException):
|
class EndpointIsInvalidError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointIsNotEnable(MaasException):
|
class EndpointIsNotEnableError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelNotSupportStreamMode(MaasException):
|
class ModelNotSupportStreamModeError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ReqTextExistRisk(MaasException):
|
class ReqTextExistRiskError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RespTextExistRisk(MaasException):
|
class RespTextExistRiskError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointRateLimitExceeded(MaasException):
|
class EndpointRateLimitExceededError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServiceConnectionRefused(MaasException):
|
class ServiceConnectionRefusedError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServiceConnectionClosed(MaasException):
|
class ServiceConnectionClosedError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UnauthorizedUserForEndpoint(MaasException):
|
class UnauthorizedUserForEndpointError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidEndpointWithNoURL(MaasException):
|
class InvalidEndpointWithNoURLError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointAccountRpmRateLimitExceeded(MaasException):
|
class EndpointAccountRpmRateLimitExceededError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointAccountTpmRateLimitExceeded(MaasException):
|
class EndpointAccountTpmRateLimitExceededError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServiceResourceWaitQueueFull(MaasException):
|
class ServiceResourceWaitQueueFullError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EndpointIsPending(MaasException):
|
class EndpointIsPendingError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServiceNotOpen(MaasException):
|
class ServiceNotOpenError(MaasError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
AuthErrors = {
|
AuthErrors = {
|
||||||
"SignatureDoesNotMatch": SignatureDoesNotMatch,
|
"SignatureDoesNotMatch": SignatureDoesNotMatchError,
|
||||||
"MissingAuthenticationHeader": MissingAuthenticationHeader,
|
"MissingAuthenticationHeader": MissingAuthenticationHeaderError,
|
||||||
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid,
|
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
|
||||||
"AuthenticationExpire": AuthenticationExpire,
|
"AuthenticationExpire": AuthenticationExpireError,
|
||||||
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint,
|
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
|
||||||
}
|
}
|
||||||
|
|
||||||
BadRequestErrors = {
|
BadRequestErrors = {
|
||||||
"MissingParameter": MissingParameter,
|
"MissingParameter": MissingParameterError,
|
||||||
"InvalidParameter": InvalidParameter,
|
"InvalidParameter": InvalidParameterError,
|
||||||
"EndpointIsInvalid": EndpointIsInvalid,
|
"EndpointIsInvalid": EndpointIsInvalidError,
|
||||||
"EndpointIsNotEnable": EndpointIsNotEnable,
|
"EndpointIsNotEnable": EndpointIsNotEnableError,
|
||||||
"ModelNotSupportStreamMode": ModelNotSupportStreamMode,
|
"ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
|
||||||
"ReqTextExistRisk": ReqTextExistRisk,
|
"ReqTextExistRisk": ReqTextExistRiskError,
|
||||||
"RespTextExistRisk": RespTextExistRisk,
|
"RespTextExistRisk": RespTextExistRiskError,
|
||||||
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL,
|
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
|
||||||
"ServiceNotOpen": ServiceNotOpen,
|
"ServiceNotOpen": ServiceNotOpenError,
|
||||||
}
|
}
|
||||||
|
|
||||||
RateLimitErrors = {
|
RateLimitErrors = {
|
||||||
"EndpointRateLimitExceeded": EndpointRateLimitExceeded,
|
"EndpointRateLimitExceeded": EndpointRateLimitExceededError,
|
||||||
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded,
|
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
|
||||||
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded,
|
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
|
||||||
}
|
}
|
||||||
|
|
||||||
ServerUnavailableErrors = {
|
ServerUnavailableErrors = {
|
||||||
"InternalServiceError": InternalServiceError,
|
"InternalServiceError": InternalServiceError,
|
||||||
"EndpointIsPending": EndpointIsPending,
|
"EndpointIsPending": EndpointIsPendingError,
|
||||||
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull,
|
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
|
||||||
}
|
}
|
||||||
|
|
||||||
ConnectionErrors = {
|
ConnectionErrors = {
|
||||||
"ClientSDKRequestError": ClientSDKRequestError,
|
"ClientSDKRequestError": ClientSDKRequestError,
|
||||||
"RequestTimeout": RequestTimeout,
|
"RequestTimeout": RequestTimeoutError,
|
||||||
"ServiceConnectionTimeout": ServiceConnectionTimeout,
|
"ServiceConnectionTimeout": ServiceConnectionTimeoutError,
|
||||||
"ServiceConnectionRefused": ServiceConnectionRefused,
|
"ServiceConnectionRefused": ServiceConnectionRefusedError,
|
||||||
"ServiceConnectionClosed": ServiceConnectionClosed,
|
"ServiceConnectionClosed": ServiceConnectionClosedError,
|
||||||
}
|
}
|
||||||
|
|
||||||
ErrorCodeMap = {
|
ErrorCodeMap = {
|
||||||
@ -150,7 +150,7 @@ ErrorCodeMap = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def wrap_error(e: MaasException) -> Exception:
|
def wrap_error(e: MaasError) -> Exception:
|
||||||
if ErrorCodeMap.get(e.code):
|
if ErrorCodeMap.get(e.code):
|
||||||
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
|
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
|
||||||
return e
|
return e
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from .common import ChatRole
|
from .common import ChatRole
|
||||||
from .maas import MaasException, MaasService
|
from .maas import MaasError, MaasService
|
||||||
|
|
||||||
__all__ = ["MaasService", "ChatRole", "MaasException"]
|
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||||
|
@ -74,7 +74,7 @@ class Signer:
|
|||||||
def sign(request, credentials):
|
def sign(request, credentials):
|
||||||
if request.path == "":
|
if request.path == "":
|
||||||
request.path = "/"
|
request.path = "/"
|
||||||
if request.method != "GET" and not ("Content-Type" in request.headers):
|
if request.method != "GET" and "Content-Type" not in request.headers:
|
||||||
request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8"
|
request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8"
|
||||||
|
|
||||||
format_date = Signer.get_current_format_date()
|
format_date = Signer.get_current_format_date()
|
||||||
|
@ -31,7 +31,7 @@ class Service:
|
|||||||
self.service_info.scheme = scheme
|
self.service_info.scheme = scheme
|
||||||
|
|
||||||
def get(self, api, params, doseq=0):
|
def get(self, api, params, doseq=0):
|
||||||
if not (api in self.api_info):
|
if api not in self.api_info:
|
||||||
raise Exception("no such api")
|
raise Exception("no such api")
|
||||||
api_info = self.api_info[api]
|
api_info = self.api_info[api]
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ class Service:
|
|||||||
raise Exception(resp.text)
|
raise Exception(resp.text)
|
||||||
|
|
||||||
def post(self, api, params, form):
|
def post(self, api, params, form):
|
||||||
if not (api in self.api_info):
|
if api not in self.api_info:
|
||||||
raise Exception("no such api")
|
raise Exception("no such api")
|
||||||
api_info = self.api_info[api]
|
api_info = self.api_info[api]
|
||||||
r = self.prepare_request(api_info, params)
|
r = self.prepare_request(api_info, params)
|
||||||
@ -72,7 +72,7 @@ class Service:
|
|||||||
raise Exception(resp.text)
|
raise Exception(resp.text)
|
||||||
|
|
||||||
def json(self, api, params, body):
|
def json(self, api, params, body):
|
||||||
if not (api in self.api_info):
|
if api not in self.api_info:
|
||||||
raise Exception("no such api")
|
raise Exception("no such api")
|
||||||
api_info = self.api_info[api]
|
api_info = self.api_info[api]
|
||||||
r = self.prepare_request(api_info, params)
|
r = self.prepare_request(api_info, params)
|
||||||
|
@ -63,7 +63,7 @@ class MaasService(Service):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
if res.error is not None and res.error.code_n != 0:
|
if res.error is not None and res.error.code_n != 0:
|
||||||
raise MaasException(
|
raise MaasError(
|
||||||
res.error.code_n,
|
res.error.code_n,
|
||||||
res.error.code,
|
res.error.code,
|
||||||
res.error.message,
|
res.error.message,
|
||||||
@ -72,7 +72,7 @@ class MaasService(Service):
|
|||||||
yield res
|
yield res
|
||||||
|
|
||||||
return iter_fn()
|
return iter_fn()
|
||||||
except MaasException:
|
except MaasError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise new_client_sdk_request_error(str(e))
|
raise new_client_sdk_request_error(str(e))
|
||||||
@ -94,7 +94,7 @@ class MaasService(Service):
|
|||||||
resp["req_id"] = req_id
|
resp["req_id"] = req_id
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
except MaasException as e:
|
except MaasError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise new_client_sdk_request_error(str(e), req_id)
|
raise new_client_sdk_request_error(str(e), req_id)
|
||||||
@ -109,7 +109,7 @@ class MaasService(Service):
|
|||||||
if not self._apikey and not credentials_exist:
|
if not self._apikey and not credentials_exist:
|
||||||
raise new_client_sdk_request_error("no valid credential", req_id)
|
raise new_client_sdk_request_error("no valid credential", req_id)
|
||||||
|
|
||||||
if not (api in self.api_info):
|
if api not in self.api_info:
|
||||||
raise new_client_sdk_request_error("no such api", req_id)
|
raise new_client_sdk_request_error("no such api", req_id)
|
||||||
|
|
||||||
def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
|
def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
|
||||||
@ -147,14 +147,14 @@ class MaasService(Service):
|
|||||||
raise new_client_sdk_request_error(raw, req_id)
|
raise new_client_sdk_request_error(raw, req_id)
|
||||||
|
|
||||||
if resp.error:
|
if resp.error:
|
||||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
||||||
else:
|
else:
|
||||||
raise new_client_sdk_request_error(resp, req_id)
|
raise new_client_sdk_request_error(resp, req_id)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class MaasException(Exception):
|
class MaasError(Exception):
|
||||||
def __init__(self, code_n, code, message, req_id):
|
def __init__(self, code_n, code, message, req_id):
|
||||||
self.code_n = code_n
|
self.code_n = code_n
|
||||||
self.code = code
|
self.code = code
|
||||||
@ -172,7 +172,7 @@ class MaasException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def new_client_sdk_request_error(raw, req_id=""):
|
def new_client_sdk_request_error(raw, req_id=""):
|
||||||
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
||||||
|
|
||||||
|
|
||||||
class BinaryResponseContent:
|
class BinaryResponseContent:
|
||||||
@ -192,7 +192,7 @@ class BinaryResponseContent:
|
|||||||
|
|
||||||
if len(error_bytes) > 0:
|
if len(error_bytes) > 0:
|
||||||
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
|
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
|
||||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
||||||
|
|
||||||
def iter_bytes(self) -> Iterator[bytes]:
|
def iter_bytes(self) -> Iterator[bytes]:
|
||||||
yield from self.response
|
yield from self.response
|
||||||
|
@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
|||||||
AuthErrors,
|
AuthErrors,
|
||||||
BadRequestErrors,
|
BadRequestErrors,
|
||||||
ConnectionErrors,
|
ConnectionErrors,
|
||||||
MaasException,
|
MaasError,
|
||||||
RateLimitErrors,
|
RateLimitErrors,
|
||||||
ServerUnavailableErrors,
|
ServerUnavailableErrors,
|
||||||
)
|
)
|
||||||
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
|||||||
},
|
},
|
||||||
[UserPromptMessage(content="ping\nAnswer: ")],
|
[UserPromptMessage(content="ping\nAnswer: ")],
|
||||||
)
|
)
|
||||||
except MaasException as e:
|
except MaasError as e:
|
||||||
raise CredentialsValidateFailedError(e.message)
|
raise CredentialsValidateFailedError(e.message)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
|||||||
AuthErrors,
|
AuthErrors,
|
||||||
BadRequestErrors,
|
BadRequestErrors,
|
||||||
ConnectionErrors,
|
ConnectionErrors,
|
||||||
MaasException,
|
MaasError,
|
||||||
RateLimitErrors,
|
RateLimitErrors,
|
||||||
ServerUnavailableErrors,
|
ServerUnavailableErrors,
|
||||||
)
|
)
|
||||||
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|||||||
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||||
try:
|
try:
|
||||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
except MaasException as e:
|
except MaasError as e:
|
||||||
raise CredentialsValidateFailedError(e.message)
|
raise CredentialsValidateFailedError(e.message)
|
||||||
|
|
||||||
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||||
|
@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
|
|||||||
InvokeRateLimitError: [RateLimitReachedError],
|
InvokeRateLimitError: [RateLimitReachedError],
|
||||||
InvokeAuthorizationError: [
|
InvokeAuthorizationError: [
|
||||||
InvalidAuthenticationError,
|
InvalidAuthenticationError,
|
||||||
InsufficientAccountBalance,
|
InsufficientAccountBalanceError,
|
||||||
InvalidAPIKeyError,
|
InvalidAPIKeyError,
|
||||||
],
|
],
|
||||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||||
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InsufficientAccountBalance(Exception):
|
class InsufficientAccountBalanceError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -272,11 +272,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
|||||||
"""
|
"""
|
||||||
text = ""
|
text = ""
|
||||||
for item in message:
|
for item in message:
|
||||||
if isinstance(item, UserPromptMessage):
|
if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage):
|
||||||
text += item.content
|
|
||||||
elif isinstance(item, SystemPromptMessage):
|
|
||||||
text += item.content
|
|
||||||
elif isinstance(item, AssistantPromptMessage):
|
|
||||||
text += item.content
|
text += item.content
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
raise NotImplementedError(f"PromptMessage type {type(item)} is not supported")
|
||||||
|
@ -209,9 +209,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
):
|
):
|
||||||
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
||||||
else:
|
else:
|
||||||
if copy_prompt_message.role == PromptMessageRole.USER:
|
if (
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
copy_prompt_message.role == PromptMessageRole.USER
|
||||||
elif copy_prompt_message.role == PromptMessageRole.TOOL:
|
or copy_prompt_message.role == PromptMessageRole.TOOL
|
||||||
|
):
|
||||||
new_prompt_messages.append(copy_prompt_message)
|
new_prompt_messages.append(copy_prompt_message)
|
||||||
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
||||||
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
||||||
@ -461,9 +462,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
message_text = f"{human_prompt} {content}"
|
message_text = f"{human_prompt} {content}"
|
||||||
elif isinstance(message, AssistantPromptMessage):
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
message_text = f"{ai_prompt} {content}"
|
message_text = f"{ai_prompt} {content}"
|
||||||
elif isinstance(message, SystemPromptMessage):
|
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
|
||||||
message_text = content
|
|
||||||
elif isinstance(message, ToolPromptMessage):
|
|
||||||
message_text = content
|
message_text = content
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unknown type {message}")
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
|
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
|
||||||
# inputs_config
|
# inputs_config
|
||||||
inputs_config = config.get("inputs_config")
|
inputs_config = config.get("inputs_config")
|
||||||
if not isinstance(inputs_config, dict):
|
if not isinstance(inputs_config, dict):
|
||||||
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
|
|||||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||||
|
|
||||||
|
|
||||||
class ModerationException(Exception):
|
class ModerationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -2,7 +2,7 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig
|
from core.app.app_config.entities import AppConfig
|
||||||
from core.moderation.base import ModerationAction, ModerationException
|
from core.moderation.base import ModerationAction, ModerationError
|
||||||
from core.moderation.factory import ModerationFactory
|
from core.moderation.factory import ModerationFactory
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
@ -61,7 +61,7 @@ class InputModeration:
|
|||||||
return False, inputs, query
|
return False, inputs, query
|
||||||
|
|
||||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||||
raise ModerationException(moderation_result.preset_response)
|
raise ModerationError(moderation_result.preset_response)
|
||||||
elif moderation_result.action == ModerationAction.OVERRIDDEN:
|
elif moderation_result.action == ModerationAction.OVERRIDDEN:
|
||||||
inputs = moderation_result.inputs
|
inputs = moderation_result.inputs
|
||||||
query = moderation_result.query
|
query = moderation_result.query
|
||||||
|
@ -56,14 +56,7 @@ class KeywordsModeration(Moderation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
|
||||||
for value in inputs.values():
|
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
|
||||||
if self._check_keywords_in_value(keywords_list, value):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
def _check_keywords_in_value(self, keywords_list, value) -> bool:
|
||||||
|
return any(keyword.lower() in value.lower() for keyword in keywords_list)
|
||||||
def _check_keywords_in_value(self, keywords_list, value):
|
|
||||||
for keyword in keywords_list:
|
|
||||||
if keyword.lower() in value.lower():
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
|
|||||||
host: str = "https://api.langfuse.com"
|
host: str = "https://api.langfuse.com"
|
||||||
|
|
||||||
@field_validator("host")
|
@field_validator("host")
|
||||||
|
@classmethod
|
||||||
def set_value(cls, v, info: ValidationInfo):
|
def set_value(cls, v, info: ValidationInfo):
|
||||||
if v is None or v == "":
|
if v is None or v == "":
|
||||||
v = "https://api.langfuse.com"
|
v = "https://api.langfuse.com"
|
||||||
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
|
|||||||
endpoint: str = "https://api.smith.langchain.com"
|
endpoint: str = "https://api.smith.langchain.com"
|
||||||
|
|
||||||
@field_validator("endpoint")
|
@field_validator("endpoint")
|
||||||
|
@classmethod
|
||||||
def set_value(cls, v, info: ValidationInfo):
|
def set_value(cls, v, info: ValidationInfo):
|
||||||
if v is None or v == "":
|
if v is None or v == "":
|
||||||
v = "https://api.smith.langchain.com"
|
v = "https://api.smith.langchain.com"
|
||||||
|
@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
|
|||||||
metadata: dict[str, Any]
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
@field_validator("inputs", "outputs")
|
@field_validator("inputs", "outputs")
|
||||||
|
@classmethod
|
||||||
def ensure_type(cls, v):
|
def ensure_type(cls, v):
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
|
@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("input", "output")
|
@field_validator("input", "output")
|
||||||
|
@classmethod
|
||||||
def ensure_dict(cls, v, info: ValidationInfo):
|
def ensure_dict(cls, v, info: ValidationInfo):
|
||||||
field_name = info.field_name
|
field_name = info.field_name
|
||||||
return validate_input_output(v, field_name)
|
return validate_input_output(v, field_name)
|
||||||
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("input", "output")
|
@field_validator("input", "output")
|
||||||
|
@classmethod
|
||||||
def ensure_dict(cls, v, info: ValidationInfo):
|
def ensure_dict(cls, v, info: ValidationInfo):
|
||||||
field_name = info.field_name
|
field_name = info.field_name
|
||||||
return validate_input_output(v, field_name)
|
return validate_input_output(v, field_name)
|
||||||
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
|
|||||||
totalCost: Optional[float] = None
|
totalCost: Optional[float] = None
|
||||||
|
|
||||||
@field_validator("input", "output")
|
@field_validator("input", "output")
|
||||||
|
@classmethod
|
||||||
def ensure_dict(cls, v, info: ValidationInfo):
|
def ensure_dict(cls, v, info: ValidationInfo):
|
||||||
field_name = info.field_name
|
field_name = info.field_name
|
||||||
return validate_input_output(v, field_name)
|
return validate_input_output(v, field_name)
|
||||||
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
|
|||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
@field_validator("input", "output")
|
@field_validator("input", "output")
|
||||||
|
@classmethod
|
||||||
def ensure_dict(cls, v, info: ValidationInfo):
|
def ensure_dict(cls, v, info: ValidationInfo):
|
||||||
field_name = info.field_name
|
field_name = info.field_name
|
||||||
return validate_input_output(v, field_name)
|
return validate_input_output(v, field_name)
|
||||||
|
@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
|||||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||||
|
|
||||||
@field_validator("inputs", "outputs")
|
@field_validator("inputs", "outputs")
|
||||||
|
@classmethod
|
||||||
def ensure_dict(cls, v, info: ValidationInfo):
|
def ensure_dict(cls, v, info: ValidationInfo):
|
||||||
field_name = info.field_name
|
field_name = info.field_name
|
||||||
values = info.data
|
values = info.data
|
||||||
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
|||||||
return v
|
return v
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
@field_validator("start_time", "end_time")
|
@field_validator("start_time", "end_time")
|
||||||
def format_time(cls, v, info: ValidationInfo):
|
def format_time(cls, v, info: ValidationInfo):
|
||||||
if not isinstance(v, datetime):
|
if not isinstance(v, datetime):
|
||||||
|
@ -223,7 +223,7 @@ class OpsTraceManager:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# auth check
|
# auth check
|
||||||
if tracing_provider not in provider_config_map.keys() and tracing_provider is not None:
|
if tracing_provider not in provider_config_map and tracing_provider is not None:
|
||||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||||
|
|
||||||
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
app_config: App = db.session.query(App).filter(App.id == app_id).first()
|
||||||
|
@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config HOST is required")
|
raise ValueError("config HOST is required")
|
||||||
|
@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
|
|||||||
database: str = "default"
|
database: str = "default"
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values.get("uri"):
|
if not values.get("uri"):
|
||||||
raise ValueError("config MILVUS_URI is required")
|
raise ValueError("config MILVUS_URI is required")
|
||||||
|
@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
|
|||||||
secure: bool = False
|
secure: bool = False
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values.get("host"):
|
if not values.get("host"):
|
||||||
raise ValueError("config OPENSEARCH_HOST is required")
|
raise ValueError("config OPENSEARCH_HOST is required")
|
||||||
|
@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
|
|||||||
database: str
|
database: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config ORACLE_HOST is required")
|
raise ValueError("config ORACLE_HOST is required")
|
||||||
|
@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
|
|||||||
database: str
|
database: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||||
|
@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
|
|||||||
database: str
|
database: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config PGVECTOR_HOST is required")
|
raise ValueError("config PGVECTOR_HOST is required")
|
||||||
|
@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
|
|||||||
database: str
|
database: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config RELYT_HOST is required")
|
raise ValueError("config RELYT_HOST is required")
|
||||||
@ -126,27 +127,26 @@ class RelytVector(BaseVector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunks_table_data = []
|
chunks_table_data = []
|
||||||
with self.client.connect() as conn:
|
with self.client.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
||||||
for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings):
|
chunks_table_data.append(
|
||||||
chunks_table_data.append(
|
{
|
||||||
{
|
"id": chunk_id,
|
||||||
"id": chunk_id,
|
"embedding": embedding,
|
||||||
"embedding": embedding,
|
"document": document,
|
||||||
"document": document,
|
"metadata": metadata,
|
||||||
"metadata": metadata,
|
}
|
||||||
}
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Execute the batch insert when the batch size is reached
|
# Execute the batch insert when the batch size is reached
|
||||||
if len(chunks_table_data) == 500:
|
if len(chunks_table_data) == 500:
|
||||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
|
||||||
# Clear the chunks_table_data list for the next batch
|
|
||||||
chunks_table_data.clear()
|
|
||||||
|
|
||||||
# Insert any remaining records that didn't make up a full batch
|
|
||||||
if chunks_table_data:
|
|
||||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||||
|
# Clear the chunks_table_data list for the next batch
|
||||||
|
chunks_table_data.clear()
|
||||||
|
|
||||||
|
# Insert any remaining records that didn't make up a full batch
|
||||||
|
if chunks_table_data:
|
||||||
|
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
@ -185,11 +185,10 @@ class RelytVector(BaseVector):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self.client.connect() as conn:
|
with self.client.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
delete_condition = chunks_table.c.id.in_(ids)
|
||||||
delete_condition = chunks_table.c.id.in_(ids)
|
conn.execute(chunks_table.delete().where(delete_condition))
|
||||||
conn.execute(chunks_table.delete().where(delete_condition))
|
return True
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Delete operation failed:", str(e))
|
print("Delete operation failed:", str(e))
|
||||||
return False
|
return False
|
||||||
|
@ -63,10 +63,7 @@ class TencentVector(BaseVector):
|
|||||||
|
|
||||||
def _has_collection(self) -> bool:
|
def _has_collection(self) -> bool:
|
||||||
collections = self._db.list_collections()
|
collections = self._db.list_collections()
|
||||||
for collection in collections:
|
return any(collection.collection_name == self._collection_name for collection in collections)
|
||||||
if collection.collection_name == self._collection_name:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _create_collection(self, dimension: int) -> None:
|
def _create_collection(self, dimension: int) -> None:
|
||||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||||
|
@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
|
|||||||
program_name: str
|
program_name: str
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["host"]:
|
if not values["host"]:
|
||||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||||
@ -123,20 +124,19 @@ class TiDBVector(BaseVector):
|
|||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
|
|
||||||
chunks_table_data = []
|
chunks_table_data = []
|
||||||
with self._engine.connect() as conn:
|
with self._engine.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
||||||
for id, text, meta, embedding in zip(ids, texts, metas, embeddings):
|
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
||||||
chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})
|
|
||||||
|
|
||||||
# Execute the batch insert when the batch size is reached
|
# Execute the batch insert when the batch size is reached
|
||||||
if len(chunks_table_data) == 500:
|
if len(chunks_table_data) == 500:
|
||||||
conn.execute(insert(table).values(chunks_table_data))
|
|
||||||
# Clear the chunks_table_data list for the next batch
|
|
||||||
chunks_table_data.clear()
|
|
||||||
|
|
||||||
# Insert any remaining records that didn't make up a full batch
|
|
||||||
if chunks_table_data:
|
|
||||||
conn.execute(insert(table).values(chunks_table_data))
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
|
# Clear the chunks_table_data list for the next batch
|
||||||
|
chunks_table_data.clear()
|
||||||
|
|
||||||
|
# Insert any remaining records that didn't make up a full batch
|
||||||
|
if chunks_table_data:
|
||||||
|
conn.execute(insert(table).values(chunks_table_data))
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
@ -159,11 +159,10 @@ class TiDBVector(BaseVector):
|
|||||||
raise ValueError("No ids provided to delete.")
|
raise ValueError("No ids provided to delete.")
|
||||||
table = self._table(self._dimension)
|
table = self._table(self._dimension)
|
||||||
try:
|
try:
|
||||||
with self._engine.connect() as conn:
|
with self._engine.connect() as conn, conn.begin():
|
||||||
with conn.begin():
|
delete_condition = table.c.id.in_(ids)
|
||||||
delete_condition = table.c.id.in_(ids)
|
conn.execute(table.delete().where(delete_condition))
|
||||||
conn.execute(table.delete().where(delete_condition))
|
return True
|
||||||
return True
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Delete operation failed:", str(e))
|
print("Delete operation failed:", str(e))
|
||||||
return False
|
return False
|
||||||
|
@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
|
|||||||
batch_size: int = 100
|
batch_size: int = 100
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values["endpoint"]:
|
if not values["endpoint"]:
|
||||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||||
|
@ -48,7 +48,8 @@ class WordExtractor(BaseExtractor):
|
|||||||
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
|
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
|
||||||
|
|
||||||
self.web_path = self.file_path
|
self.web_path = self.file_path
|
||||||
self.temp_file = tempfile.NamedTemporaryFile()
|
# TODO: use a better way to handle the file
|
||||||
|
self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115
|
||||||
self.temp_file.write(r.content)
|
self.temp_file.write(r.content)
|
||||||
self.file_path = self.temp_file.name
|
self.file_path = self.temp_file.name
|
||||||
elif not os.path.isfile(self.file_path):
|
elif not os.path.isfile(self.file_path):
|
||||||
|
@ -120,8 +120,8 @@ class WeightRerankRunner:
|
|||||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||||
|
|
||||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
sum1 = sum(vec1[x] ** 2 for x in vec1)
|
||||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
sum2 = sum(vec2[x] ** 2 for x in vec2)
|
||||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||||
|
|
||||||
if not denominator:
|
if not denominator:
|
||||||
|
@ -581,8 +581,8 @@ class DatasetRetrieval:
|
|||||||
intersection = set(vec1.keys()) & set(vec2.keys())
|
intersection = set(vec1.keys()) & set(vec2.keys())
|
||||||
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
||||||
|
|
||||||
sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
sum1 = sum(vec1[x] ** 2 for x in vec1)
|
||||||
sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
sum2 = sum(vec2[x] ** 2 for x in vec2)
|
||||||
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
||||||
|
|
||||||
if not denominator:
|
if not denominator:
|
||||||
|
@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
|
|||||||
self.create_blob_message(
|
self.create_blob_message(
|
||||||
blob=b64decode(image.b64_json),
|
blob=b64decode(image.b64_json),
|
||||||
meta={"mime_type": "image/png"},
|
meta={"mime_type": "image/png"},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
save_as=self.VariableKey.IMAGE.value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||||
|
@ -71,7 +71,7 @@ class BingSearchTool(BuiltinTool):
|
|||||||
text = ""
|
text = ""
|
||||||
if search_results:
|
if search_results:
|
||||||
for i, result in enumerate(search_results):
|
for i, result in enumerate(search_results):
|
||||||
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||||
|
|
||||||
if computation and "expression" in computation and "value" in computation:
|
if computation and "expression" in computation and "value" in computation:
|
||||||
text += "\nComputation:\n"
|
text += "\nComputation:\n"
|
||||||
|
@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
|
|||||||
self.create_blob_message(
|
self.create_blob_message(
|
||||||
blob=b64decode(image.b64_json),
|
blob=b64decode(image.b64_json),
|
||||||
meta={"mime_type": "image/png"},
|
meta={"mime_type": "image/png"},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
save_as=self.VariableKey.IMAGE.value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
|
|||||||
for image in response.data:
|
for image in response.data:
|
||||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||||
blob_message = self.create_blob_message(
|
blob_message = self.create_blob_message(
|
||||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
|
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||||
)
|
)
|
||||||
result.append(blob_message)
|
result.append(blob_message)
|
||||||
return result
|
return result
|
||||||
|
@ -83,5 +83,5 @@ class DIDApp:
|
|||||||
if status["status"] == "done":
|
if status["status"] == "done":
|
||||||
return status
|
return status
|
||||||
elif status["status"] == "error" or status["status"] == "rejected":
|
elif status["status"] == "error" or status["status"] == "rejected":
|
||||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
|
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import urllib.parse
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
@ -13,13 +14,14 @@ class GitlabCommitsTool(BuiltinTool):
|
|||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
self, user_id: str, tool_parameters: dict[str, Any]
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
project = tool_parameters.get("project", "")
|
project = tool_parameters.get("project", "")
|
||||||
|
repository = tool_parameters.get("repository", "")
|
||||||
employee = tool_parameters.get("employee", "")
|
employee = tool_parameters.get("employee", "")
|
||||||
start_time = tool_parameters.get("start_time", "")
|
start_time = tool_parameters.get("start_time", "")
|
||||||
end_time = tool_parameters.get("end_time", "")
|
end_time = tool_parameters.get("end_time", "")
|
||||||
change_type = tool_parameters.get("change_type", "all")
|
change_type = tool_parameters.get("change_type", "all")
|
||||||
|
|
||||||
if not project:
|
if not project and not repository:
|
||||||
return self.create_text_message("Project is required")
|
return self.create_text_message("Either project or repository is required")
|
||||||
|
|
||||||
if not start_time:
|
if not start_time:
|
||||||
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
|
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
|
||||||
@ -35,91 +37,105 @@ class GitlabCommitsTool(BuiltinTool):
|
|||||||
site_url = "https://gitlab.com"
|
site_url = "https://gitlab.com"
|
||||||
|
|
||||||
# Get commit content
|
# Get commit content
|
||||||
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
|
if repository:
|
||||||
|
result = self.fetch_commits(
|
||||||
|
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = self.fetch_commits(
|
||||||
|
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
|
||||||
|
)
|
||||||
|
|
||||||
return [self.create_json_message(item) for item in result]
|
return [self.create_json_message(item) for item in result]
|
||||||
|
|
||||||
def fetch(
|
def fetch_commits(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
|
||||||
site_url: str,
|
site_url: str,
|
||||||
access_token: str,
|
access_token: str,
|
||||||
project: str,
|
identifier: str,
|
||||||
employee: str = None,
|
employee: str,
|
||||||
start_time: str = "",
|
start_time: str,
|
||||||
end_time: str = "",
|
end_time: str,
|
||||||
change_type: str = "",
|
change_type: str,
|
||||||
|
is_repository: bool,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
domain = site_url
|
domain = site_url
|
||||||
headers = {"PRIVATE-TOKEN": access_token}
|
headers = {"PRIVATE-TOKEN": access_token}
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get all of projects
|
if is_repository:
|
||||||
url = f"{domain}/api/v4/projects"
|
# URL encode the repository path
|
||||||
response = requests.get(url, headers=headers)
|
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||||
response.raise_for_status()
|
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
|
||||||
projects = response.json()
|
else:
|
||||||
|
# Get all projects
|
||||||
|
url = f"{domain}/api/v4/projects"
|
||||||
|
response = requests.get(url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
projects = response.json()
|
||||||
|
|
||||||
filtered_projects = [p for p in projects if project == "*" or p["name"] == project]
|
filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
|
||||||
|
|
||||||
for project in filtered_projects:
|
for project in filtered_projects:
|
||||||
project_id = project["id"]
|
project_id = project["id"]
|
||||||
project_name = project["name"]
|
project_name = project["name"]
|
||||||
print(f"Project: {project_name}")
|
print(f"Project: {project_name}")
|
||||||
|
|
||||||
# Get all of project commits
|
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
|
||||||
params = {"since": start_time, "until": end_time}
|
|
||||||
if employee:
|
|
||||||
params["author"] = employee
|
|
||||||
|
|
||||||
commits_response = requests.get(commits_url, headers=headers, params=params)
|
params = {"since": start_time, "until": end_time}
|
||||||
commits_response.raise_for_status()
|
if employee:
|
||||||
commits = commits_response.json()
|
params["author"] = employee
|
||||||
|
|
||||||
for commit in commits:
|
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||||
commit_sha = commit["id"]
|
commits_response.raise_for_status()
|
||||||
author_name = commit["author_name"]
|
commits = commits_response.json()
|
||||||
|
|
||||||
|
for commit in commits:
|
||||||
|
commit_sha = commit["id"]
|
||||||
|
author_name = commit["author_name"]
|
||||||
|
|
||||||
|
if is_repository:
|
||||||
|
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
|
||||||
|
else:
|
||||||
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
||||||
diff_response = requests.get(diff_url, headers=headers)
|
|
||||||
diff_response.raise_for_status()
|
|
||||||
diffs = diff_response.json()
|
|
||||||
|
|
||||||
for diff in diffs:
|
diff_response = requests.get(diff_url, headers=headers)
|
||||||
# Calculate code lines of changed
|
diff_response.raise_for_status()
|
||||||
added_lines = diff["diff"].count("\n+")
|
diffs = diff_response.json()
|
||||||
removed_lines = diff["diff"].count("\n-")
|
|
||||||
total_changes = added_lines + removed_lines
|
|
||||||
|
|
||||||
if change_type == "new":
|
for diff in diffs:
|
||||||
if added_lines > 1:
|
# Calculate code lines of changes
|
||||||
final_code = "".join(
|
added_lines = diff["diff"].count("\n+")
|
||||||
[
|
removed_lines = diff["diff"].count("\n-")
|
||||||
line[1:]
|
total_changes = added_lines + removed_lines
|
||||||
for line in diff["diff"].split("\n")
|
|
||||||
if line.startswith("+") and not line.startswith("+++")
|
if change_type == "new":
|
||||||
]
|
if added_lines > 1:
|
||||||
)
|
final_code = "".join(
|
||||||
results.append(
|
[
|
||||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code}
|
line[1:]
|
||||||
)
|
for line in diff["diff"].split("\n")
|
||||||
else:
|
if line.startswith("+") and not line.startswith("+++")
|
||||||
if total_changes > 1:
|
]
|
||||||
final_code = "".join(
|
)
|
||||||
[
|
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
|
||||||
line[1:]
|
else:
|
||||||
for line in diff["diff"].split("\n")
|
if total_changes > 1:
|
||||||
if (line.startswith("+") or line.startswith("-"))
|
final_code = "".join(
|
||||||
and not line.startswith("+++")
|
[
|
||||||
and not line.startswith("---")
|
line[1:]
|
||||||
]
|
for line in diff["diff"].split("\n")
|
||||||
)
|
if (line.startswith("+") or line.startswith("-"))
|
||||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
and not line.startswith("+++")
|
||||||
results.append(
|
and not line.startswith("---")
|
||||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
|
]
|
||||||
)
|
)
|
||||||
|
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||||
|
results.append(
|
||||||
|
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
|
||||||
|
)
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
print(f"Error fetching data from GitLab: {e}")
|
print(f"Error fetching data from GitLab: {e}")
|
||||||
|
|
||||||
|
@ -21,9 +21,20 @@ parameters:
|
|||||||
zh_Hans: 员工用户名
|
zh_Hans: 员工用户名
|
||||||
llm_description: User name for GitLab
|
llm_description: User name for GitLab
|
||||||
form: llm
|
form: llm
|
||||||
|
- name: repository
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: repository
|
||||||
|
zh_Hans: 仓库路径
|
||||||
|
human_description:
|
||||||
|
en_US: repository
|
||||||
|
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||||
|
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||||
|
form: llm
|
||||||
- name: project
|
- name: project
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
label:
|
label:
|
||||||
en_US: project
|
en_US: project
|
||||||
zh_Hans: 项目名
|
zh_Hans: 项目名
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import urllib.parse
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -11,14 +12,14 @@ class GitlabFilesTool(BuiltinTool):
|
|||||||
self, user_id: str, tool_parameters: dict[str, Any]
|
self, user_id: str, tool_parameters: dict[str, Any]
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
project = tool_parameters.get("project", "")
|
project = tool_parameters.get("project", "")
|
||||||
|
repository = tool_parameters.get("repository", "")
|
||||||
branch = tool_parameters.get("branch", "")
|
branch = tool_parameters.get("branch", "")
|
||||||
path = tool_parameters.get("path", "")
|
path = tool_parameters.get("path", "")
|
||||||
|
|
||||||
if not project:
|
if not project and not repository:
|
||||||
return self.create_text_message("Project is required")
|
return self.create_text_message("Either project or repository is required")
|
||||||
if not branch:
|
if not branch:
|
||||||
return self.create_text_message("Branch is required")
|
return self.create_text_message("Branch is required")
|
||||||
|
|
||||||
if not path:
|
if not path:
|
||||||
return self.create_text_message("Path is required")
|
return self.create_text_message("Path is required")
|
||||||
|
|
||||||
@ -30,21 +31,59 @@ class GitlabFilesTool(BuiltinTool):
|
|||||||
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
||||||
site_url = "https://gitlab.com"
|
site_url = "https://gitlab.com"
|
||||||
|
|
||||||
# Get project ID from project name
|
# Get file content
|
||||||
project_id = self.get_project_id(site_url, access_token, project)
|
if repository:
|
||||||
if not project_id:
|
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
|
||||||
return self.create_text_message(f"Project '{project}' not found.")
|
else:
|
||||||
|
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
|
||||||
# Get commit content
|
|
||||||
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
|
|
||||||
|
|
||||||
return [self.create_json_message(item) for item in result]
|
return [self.create_json_message(item) for item in result]
|
||||||
|
|
||||||
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
|
def fetch_files(
|
||||||
parts = path.split("/", 1)
|
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
|
||||||
if len(parts) < 2:
|
) -> list[dict[str, Any]]:
|
||||||
return None, None
|
domain = site_url
|
||||||
return parts[0], parts[1]
|
headers = {"PRIVATE-TOKEN": access_token}
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_repository:
|
||||||
|
# URL encode the repository path
|
||||||
|
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||||
|
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
|
||||||
|
else:
|
||||||
|
# Get project ID from project name
|
||||||
|
project_id = self.get_project_id(site_url, access_token, identifier)
|
||||||
|
if not project_id:
|
||||||
|
return self.create_text_message(f"Project '{identifier}' not found.")
|
||||||
|
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||||
|
|
||||||
|
response = requests.get(tree_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
items = response.json()
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
item_path = item["path"]
|
||||||
|
if item["type"] == "tree": # It's a directory
|
||||||
|
results.extend(
|
||||||
|
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
|
||||||
|
)
|
||||||
|
else: # It's a file
|
||||||
|
if is_repository:
|
||||||
|
file_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/files/{item_path}/raw?ref={branch}"
|
||||||
|
else:
|
||||||
|
file_url = (
|
||||||
|
f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_response = requests.get(file_url, headers=headers)
|
||||||
|
file_response.raise_for_status()
|
||||||
|
file_content = file_response.text
|
||||||
|
results.append({"path": item_path, "branch": branch, "content": file_content})
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error fetching data from GitLab: {e}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
|
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
|
||||||
headers = {"PRIVATE-TOKEN": access_token}
|
headers = {"PRIVATE-TOKEN": access_token}
|
||||||
@ -59,32 +98,3 @@ class GitlabFilesTool(BuiltinTool):
|
|||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
print(f"Error fetching project ID from GitLab: {e}")
|
print(f"Error fetching project ID from GitLab: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def fetch(
|
|
||||||
self, user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
domain = site_url
|
|
||||||
headers = {"PRIVATE-TOKEN": access_token}
|
|
||||||
results = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# List files and directories in the given path
|
|
||||||
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
|
||||||
response = requests.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
items = response.json()
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
item_path = item["path"]
|
|
||||||
if item["type"] == "tree": # It's a directory
|
|
||||||
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
|
|
||||||
else: # It's a file
|
|
||||||
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
|
||||||
file_response = requests.get(file_url, headers=headers)
|
|
||||||
file_response.raise_for_status()
|
|
||||||
file_content = file_response.text
|
|
||||||
results.append({"path": item_path, "branch": branch, "content": file_content})
|
|
||||||
except requests.RequestException as e:
|
|
||||||
print(f"Error fetching data from GitLab: {e}")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
@ -10,9 +10,20 @@ description:
|
|||||||
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
|
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
|
||||||
llm: A tool for query GitLab files, Input should be a exists file or directory path.
|
llm: A tool for query GitLab files, Input should be a exists file or directory path.
|
||||||
parameters:
|
parameters:
|
||||||
|
- name: repository
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
label:
|
||||||
|
en_US: repository
|
||||||
|
zh_Hans: 仓库路径
|
||||||
|
human_description:
|
||||||
|
en_US: repository
|
||||||
|
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||||
|
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||||
|
form: llm
|
||||||
- name: project
|
- name: project
|
||||||
type: string
|
type: string
|
||||||
required: true
|
required: false
|
||||||
label:
|
label:
|
||||||
en_US: project
|
en_US: project
|
||||||
zh_Hans: 项目
|
zh_Hans: 项目
|
||||||
|
@ -142,7 +142,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
|||||||
for control in controls:
|
for control in controls:
|
||||||
control_type_id = self.get_real_type_id(control)
|
control_type_id = self.get_real_type_id(control)
|
||||||
if (control_type_id in self._get_ignore_types()) or (
|
if (control_type_id in self._get_ignore_types()) or (
|
||||||
allow_fields and not control["controlId"] in allow_fields
|
allow_fields and control["controlId"] not in allow_fields
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -201,9 +201,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
|
|||||||
elif value.startswith('[{"organizeId"'):
|
elif value.startswith('[{"organizeId"'):
|
||||||
value = json.loads(value)
|
value = json.loads(value)
|
||||||
value = "、".join([item["organizeName"] for item in value])
|
value = "、".join([item["organizeName"] for item in value])
|
||||||
elif value.startswith('[{"file_id"'):
|
elif value.startswith('[{"file_id"') or value == "[]":
|
||||||
value = ""
|
|
||||||
elif value == "[]":
|
|
||||||
value = ""
|
value = ""
|
||||||
elif hasattr(value, "accountId"):
|
elif hasattr(value, "accountId"):
|
||||||
value = value["fullname"]
|
value = value["fullname"]
|
||||||
|
@ -67,7 +67,7 @@ class ListWorksheetsTool(BuiltinTool):
|
|||||||
items = []
|
items = []
|
||||||
tables = ""
|
tables = ""
|
||||||
for item in section.get("items", []):
|
for item in section.get("items", []):
|
||||||
if item.get("type") == 0 and (not "notes" in item or item.get("notes") != "NO"):
|
if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"):
|
||||||
if type == "json":
|
if type == "json":
|
||||||
filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")}
|
filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")}
|
||||||
items.append(filtered_item)
|
items.append(filtered_item)
|
||||||
|
@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
|
|||||||
self.create_blob_message(
|
self.create_blob_message(
|
||||||
blob=b64decode(client_result.image_file),
|
blob=b64decode(client_result.image_file),
|
||||||
meta={"mime_type": f"image/{client_result.image_type}"},
|
meta={"mime_type": f"image/{client_result.image_type}"},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
save_as=self.VariableKey.IMAGE.value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool):
|
|||||||
models_data=[],
|
models_data=[],
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=params,
|
params=params,
|
||||||
recursive=False if result_type == "first sd_name" or result_type == "first name sd_name pair" else True,
|
recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"),
|
||||||
)
|
)
|
||||||
|
|
||||||
result_str = ""
|
result_str = ""
|
||||||
|
@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
|||||||
self.create_blob_message(
|
self.create_blob_message(
|
||||||
blob=b64decode(image_encoded),
|
blob=b64decode(image_encoded),
|
||||||
meta={"mime_type": f"image/{image.image_type}"},
|
meta={"mime_type": f"image/{image.image_type}"},
|
||||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
save_as=self.VariableKey.IMAGE.value,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,14 +39,14 @@ class QRCodeGeneratorTool(BuiltinTool):
|
|||||||
|
|
||||||
# get error_correction
|
# get error_correction
|
||||||
error_correction = tool_parameters.get("error_correction", "")
|
error_correction = tool_parameters.get("error_correction", "")
|
||||||
if error_correction not in self.error_correction_levels.keys():
|
if error_correction not in self.error_correction_levels:
|
||||||
return self.create_text_message("Invalid parameter error_correction")
|
return self.create_text_message("Invalid parameter error_correction")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = self._generate_qrcode(content, border, error_correction)
|
image = self._generate_qrcode(content, border, error_correction)
|
||||||
image_bytes = self._image_to_byte_array(image)
|
image_bytes = self._image_to_byte_array(image)
|
||||||
return self.create_blob_message(
|
return self.create_blob_message(
|
||||||
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception(f"Failed to generate QR code for content: {content}")
|
logging.exception(f"Failed to generate QR code for content: {content}")
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user