Merge branch 'main' of github.com:langgenius/dify into feat/plugins

commit e8127756e0
Author: Yi
Date:   2024-09-18 20:57:52 +08:00
245 changed files with 2882 additions and 1210 deletions
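The bulk of the Python changes below swap list (or tuple) literals for set literals in constant membership tests, `x in [...]` → `x in {...}`, likely the auto-fix for a linter rule such as ruff's literal-membership check. A minimal sketch of what the rewrite buys, with illustrative names not taken from the diff:

    import dis

    status = "completed"

    # Old style: membership test against a list literal; CPython folds it
    # to a tuple constant, but the test is still a linear scan.
    if status in ["completed", "error"]:
        print("finished")

    # New style: CPython folds the set literal into a frozenset constant,
    # so the test is a single hash lookup, and it reads as "one of these".
    if status in {"completed", "error"}:
        print("finished")

    # The disassembly shows the frozenset constant:
    dis.dis('status in {"completed", "error"}')

For two or three short strings the speedup is negligible; the value is mostly consistency and satisfying the linter.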

.gitignore

@@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/*
 docker-legacy/volumes/minio/*
 docker-legacy/volumes/milvus/*
 docker-legacy/volumes/chroma/*
+docker-legacy/volumes/opensearch/data/*
+docker-legacy/volumes/pgvectors/data/*
+docker-legacy/volumes/pgvector/data/*
 docker/volumes/app/storage/*
 docker/volumes/certbot/*
@@ -164,6 +167,12 @@ docker/volumes/etcd/*
 docker/volumes/minio/*
 docker/volumes/milvus/*
 docker/volumes/chroma/*
+docker/volumes/opensearch/data/*
+docker/volumes/myscale/data/*
+docker/volumes/myscale/log/*
+docker/volumes/unstructured/*
+docker/volumes/pgvector/data/*
+docker/volumes/pgvecto_rs/data/*
 docker/nginx/conf.d/default.conf
 docker/middleware.env

@@ -164,7 +164,7 @@ def initialize_extensions(app):
    @login_manager.request_loader
    def load_user_from_request(request_from_flask_login):
        """Load user based on the request."""
-        if request.blueprint not in ["console", "inner_api"]:
+        if request.blueprint not in {"console", "inner_api"}:
            return None
        # Check if the user_id contains a dot, indicating the old format
        auth_header = request.headers.get("Authorization", "")

@@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
 @click.command("vdb-migrate", help="migrate vector db.")
 @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
-    if scope in ["knowledge", "all"]:
+    if scope in {"knowledge", "all"}:
         migrate_knowledge_vector_database()
-    if scope in ["annotation", "all"]:
+    if scope in {"annotation", "all"}:
         migrate_annotation_vector_database()

@@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource):
        message_id = args.get("message_id", None)
        text = args.get("text", None)
        if (
-            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
            and app_model.workflow
            and app_model.workflow.features_dict
        ):

@@ -71,7 +71,7 @@ class OAuthCallback(Resource):
        account = _generate_account(provider, user_info)
        # Check account status
-        if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
+        if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
            return {"error": "Account is banned or closed."}, 403
        if account.status == AccountStatus.PENDING.value:

@@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
        document_id = str(document_id)
        document = self.get_document(dataset_id, document_id)
-        if document.indexing_status in ["completed", "error"]:
+        if document.indexing_status in {"completed", "error"}:
            raise DocumentAlreadyFinishedError()
        data_process_rule = document.dataset_process_rule
@@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
        info_list = []
        extract_settings = []
        for document in documents:
-            if document.indexing_status in ["completed", "error"]:
+            if document.indexing_status in {"completed", "error"}:
                raise DocumentAlreadyFinishedError()
            data_source_info = document.data_source_info_dict
            # format document files info
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
            db.session.commit()
        elif action == "resume":
-            if document.indexing_status not in ["paused", "error"]:
+            if document.indexing_status not in {"paused", "error"}:
                raise InvalidActionError("Document not in paused or error state.")
            document.paused_by = None

@@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource):
        message_id = args.get("message_id", None)
        text = args.get("text", None)
        if (
-            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
            and app_model.workflow
            and app_model.workflow.features_dict
        ):

@@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
    def post(self, installed_app):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
    def post(self, installed_app, task_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

@@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
    def get(self, installed_app):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
    def delete(self, installed_app, c_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
    def post(self, installed_app, c_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
    def patch(self, installed_app, c_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
    def patch(self, installed_app, c_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)

@@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
                "app_owner_tenant_id": installed_app.app_owner_tenant_id,
                "is_pinned": installed_app.is_pinned,
                "last_used_at": installed_app.last_used_at,
-                "editable": current_user.role in ["owner", "admin"],
+                "editable": current_user.role in {"owner", "admin"},
                "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
            }
            for installed_app in installed_apps

@@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
    def get(self, installed_app, message_id):
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        message_id = str(message_id)

@@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
        """Retrieve app parameters."""
        app_model = installed_app.app
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
            workflow = app_model.workflow
            if workflow is None:
                raise AppUnavailableError()

@@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
            raise TooManyFilesError()
        extension = file.filename.split(".")[-1]
-        if extension.lower() not in ["svg", "png"]:
+        if extension.lower() not in {"svg", "png"}:
            raise UnsupportedFileTypeError()
        try:

@@ -42,7 +42,7 @@ class AppParameterApi(Resource):
    @marshal_with(parameters_fields)
    def get(self, app_model: App):
        """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
            workflow = app_model.workflow
            if workflow is None:
                raise AppUnavailableError()

@@ -79,7 +79,7 @@ class TextApi(Resource):
        message_id = args.get("message_id", None)
        text = args.get("text", None)
        if (
-            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
            and app_model.workflow
            and app_model.workflow.features_dict
        ):

@@ -96,7 +96,7 @@ class ChatApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -144,7 +144,7 @@ class ChatStopApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, task_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

@@ -18,7 +18,7 @@ class ConversationApi(Resource):
    @marshal_with(conversation_infinite_scroll_pagination_fields)
    def get(self, app_model: App, end_user: EndUser):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
    @marshal_with(simple_conversation_fields)
    def delete(self, app_model: App, end_user: EndUser, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
    @marshal_with(simple_conversation_fields)
    def post(self, app_model: App, end_user: EndUser, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)

@@ -76,7 +76,7 @@ class MessageListApi(Resource):
    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, app_model: App, end_user: EndUser):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
    def get(self, app_model: App, end_user: EndUser, message_id):
        message_id = str(message_id)
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        try:

@@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
    @marshal_with(parameters_fields)
    def get(self, app_model: App, end_user):
        """Retrieve app parameters."""
-        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
            workflow = app_model.workflow
            if workflow is None:
                raise AppUnavailableError()

@@ -78,7 +78,7 @@ class TextApi(WebApiResource):
        message_id = args.get("message_id", None)
        text = args.get("text", None)
        if (
-            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
            and app_model.workflow
            and app_model.workflow.features_dict
        ):

@@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
 class ChatApi(WebApiResource):
    def post(self, app_model, end_user):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
 class ChatStopApi(WebApiResource):
    def post(self, app_model, end_user, task_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

@@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
    @marshal_with(conversation_infinite_scroll_pagination_fields)
    def get(self, app_model, end_user):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
 class ConversationApi(WebApiResource):
    def delete(self, app_model, end_user, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
    @marshal_with(simple_conversation_fields)
    def post(self, app_model, end_user, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
 class ConversationPinApi(WebApiResource):
    def patch(self, app_model, end_user, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)
@@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
 class ConversationUnPinApi(WebApiResource):
    def patch(self, app_model, end_user, c_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        conversation_id = str(c_id)

@@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
    @marshal_with(message_infinite_scroll_pagination_fields)
    def get(self, app_model, end_user):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()
        parser = reqparse.RequestParser()
@@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
 class MessageSuggestedQuestionApi(WebApiResource):
    def get(self, app_model, end_user, message_id):
        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotCompletionAppError()
        message_id = str(message_id)

@@ -90,7 +90,7 @@ class CotAgentOutputParser:
            if not in_code_block and not in_json:
                if delta.lower() == action_str[action_idx] and action_idx == 0:
-                    if last_character not in ["\n", " ", ""]:
+                    if last_character not in {"\n", " ", ""}:
                        index += steps
                        yield delta
                        continue
@@ -117,7 +117,7 @@ class CotAgentOutputParser:
                    action_idx = 0
                if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
-                    if last_character not in ["\n", " ", ""]:
+                    if last_character not in {"\n", " ", ""}:
                        index += steps
                        yield delta
                        continue

@@ -29,7 +29,7 @@ class BaseAppConfigManager:
        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
        additional_features.file_upload = FileUploadConfigManager.convert(
-            config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
+            config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
        )
        additional_features.opening_statement, additional_features.suggested_questions = (

@@ -18,7 +18,7 @@ class AgentConfigManager:
            if agent_strategy == "function_call":
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
-            elif agent_strategy == "cot" or agent_strategy == "react":
+            elif agent_strategy in {"cot", "react"}:
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
            else:
                # old configs, try to detect default strategy
@@ -43,10 +43,10 @@ class AgentConfigManager:
                    agent_tools.append(AgentToolEntity(**agent_tool_properties))
-            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
+            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
                "react_router",
                "router",
-            ]:
+            }:
                agent_prompt = agent_dict.get("prompt", None) or {}
                # check model mode
                model_mode = config.get("model", {}).get("mode", "completion")

@@ -167,7 +167,7 @@ class DatasetConfigManager:
            config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
        has_datasets = False
-        if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
+        if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
            for tool in config["agent_mode"]["tools"]:
                key = list(tool.keys())[0]
                if key == "dataset":

@@ -42,12 +42,12 @@ class BasicVariablesConfigManager:
                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                    )
                )
-            elif variable_type in [
+            elif variable_type in {
                VariableEntityType.TEXT_INPUT,
                VariableEntityType.PARAGRAPH,
                VariableEntityType.NUMBER,
                VariableEntityType.SELECT,
-            ]:
+            }:
                variable = variables[variable_type]
                variable_entities.append(
                    VariableEntity(
@@ -97,7 +97,7 @@ class BasicVariablesConfigManager:
        variables = []
        for item in config["user_input_form"]:
            key = list(item.keys())[0]
-            if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
+            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
            form_item = item[key]

@@ -54,14 +54,14 @@ class FileUploadConfigManager:
        if is_vision:
            detail = config["file_upload"]["image"]["detail"]
-            if detail not in ["high", "low"]:
+            if detail not in {"high", "low"}:
                raise ValueError("detail must be in ['high', 'low']")
            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
            if not isinstance(transfer_methods, list):
                raise ValueError("transfer_methods must be of list type")
            for method in transfer_methods:
-                if method not in ["remote_url", "local_file"]:
+                if method not in {"remote_url", "local_file"}:
                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
        return config, ["file_upload"]

@@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            raise ValueError("Workflow not initialized")
        user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
@@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
    def convert(
        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
    ) -> dict[str, Any] | Generator[str, Any, None]:
-        if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
+        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_full_response(response)
            else:

@@ -22,11 +22,11 @@ class BaseAppGenerator:
            return var.default or ""
        if (
            var.type
-            in (
+            in {
                VariableEntityType.TEXT_INPUT,
                VariableEntityType.SELECT,
                VariableEntityType.PARAGRAPH,
-            )
+            }
            and user_input_value
            and not isinstance(user_input_value, str)
        ):
@@ -44,7 +44,7 @@ class BaseAppGenerator:
            options = var.options or []
            if user_input_value not in options:
                raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
-        elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
+        elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
            if var.max_length and user_input_value and len(user_input_value) > var.max_length:
                raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

@@ -32,7 +32,7 @@ class AppQueueManager:
        self._user_id = user_id
        self._invoke_from = invoke_from
-        user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
        redis_client.setex(
            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
        )
@@ -118,7 +118,7 @@ class AppQueueManager:
        if result is None:
            return
-        user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
        if result.decode("utf-8") != f"{user_prefix}-{user_id}":
            return

@@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        # get from source
        end_user_id = None
        account_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            from_source = "api"
            end_user_id = application_generate_entity.user_id
        else:
@@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
        model_provider = application_generate_entity.model_conf.provider
        model_id = application_generate_entity.model_conf.model
        override_model_configs = None
-        if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
+        if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
            AppMode.AGENT_CHAT,
            AppMode.CHAT,
            AppMode.COMPLETION,
-        ]:
+        }:
            override_model_configs = app_config.app_model_config_dict
        # get conversation introduction

@@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
        app_config = cast(WorkflowAppConfig, app_config)
        user_id = None
-        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
@@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,

@@ -63,7 +63,7 @@ class AnnotationReplyFeature:
                    score = documents[0].metadata["score"]
                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                    if annotation:
-                        if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
+                        if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}:
                            from_source = "api"
                        else:
                            from_source = "console"

@@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
+            is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
            and self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras,
        )

@@ -383,7 +383,7 @@ class WorkflowCycleManage:
        :param workflow_node_execution: workflow node execution
        :return:
        """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
            return None
        response = NodeStartStreamResponse(
@@ -430,7 +430,7 @@ class WorkflowCycleManage:
        :param workflow_node_execution: workflow node execution
        :return:
        """
-        if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
            return None
        return NodeFinishStreamResponse(

@@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler:
                source="app",
                source_app_id=self._app_id,
                created_by_role=(
-                    "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+                    "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
                ),
                created_by=self._user_id,
            )

@@ -65,7 +65,7 @@ class CacheEmbedding(Embeddings):
                except IntegrityError:
                    db.session.rollback()
                except Exception as e:
-                    logging.exception("Failed transform embedding: ", e)
+                    logging.exception("Failed transform embedding: %s", e)
        cache_embeddings = []
        try:
            for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
@@ -85,7 +85,7 @@ class CacheEmbedding(Embeddings):
            db.session.rollback()
        except Exception as ex:
            db.session.rollback()
-            logger.error("Failed to embed documents: ", ex)
+            logger.error("Failed to embed documents: %s", ex)
            raise ex
        return text_embeddings
@@ -116,10 +116,7 @@ class CacheEmbedding(Embeddings):
            # Transform to string
            encoded_str = encoded_vector.decode("utf-8")
            redis_client.setex(embedding_cache_key, 600, encoded_str)
-        except IntegrityError:
-            db.session.rollback()
-        except:
-            logging.exception("Failed to add embedding to redis")
+        except Exception as ex:
+            logging.exception("Failed to add embedding to redis %s", ex)
        return embedding_results
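The logging fixes in this file deserve a note: the old calls passed the exception as an extra positional argument with no `%s` placeholder, so the lazy %-formatting inside `logging` fails at emit time and the argument is lost. A small self-contained illustration (the logger name is arbitrary):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("embedding-demo")

    try:
        raise RuntimeError("boom")
    except Exception as e:
        # Broken: no %s to consume `e`, so formatting the record raises
        # "not all arguments converted" and logging prints a
        # "--- Logging error ---" report instead of the message.
        logger.error("Failed to embed documents: ", e)

        # Fixed form from the diff: lazy %-style interpolation.
        logger.error("Failed to embed documents: %s", e)

        # logging.exception() is the same but appends the traceback,
        # which is why the diff keeps it inside except blocks.
        logger.exception("Failed to add embedding to redis %s", e)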

@@ -292,7 +292,7 @@ class IndexingRunner:
        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
    ) -> list[Document]:
        # load file
-        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
+        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
            return []
        data_source_info = dataset_document.data_source_info_dict

@@ -52,7 +52,7 @@ class TokenBufferMemory:
            files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
            if files:
                file_extra_config = None
-                if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
+                if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                    file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                else:
                    if message.workflow_run_id:

@@ -27,17 +27,17 @@ class ModelType(Enum):
        :return: model type
        """
-        if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value:
+        if origin_model_type in {"text-generation", cls.LLM.value}:
            return cls.LLM
-        elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value:
+        elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
            return cls.TEXT_EMBEDDING
-        elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value:
+        elif origin_model_type in {"reranking", cls.RERANK.value}:
            return cls.RERANK
-        elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value:
+        elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
            return cls.SPEECH2TEXT
-        elif origin_model_type == "tts" or origin_model_type == cls.TTS.value:
+        elif origin_model_type in {"tts", cls.TTS.value}:
            return cls.TTS
-        elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value:
+        elif origin_model_type in {"text2img", cls.TEXT2IMG.value}:
            return cls.TEXT2IMG
        elif origin_model_type == cls.MODERATION.value:
            return cls.MODERATION

@@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                    mime_type = data_split[0].replace("data:", "")
                    base64_data = data_split[1]
-                    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                        raise ValueError(
                            f"Unsupported image type {mime_type}, "
                            f"only support image/jpeg, image/png, image/gif, and image/webp"

@@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                    for i in range(len(sentences))
                ]
                for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
            else:
                response = client.audio.speech.with_streaming_response.create(
                    model=model, voice=voice, response_format="mp3", input=content_text.strip()
                )
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
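The `# noqa:PLC2801` markers silence ruff's unnecessary-dunder-call rule: the streaming responses are context managers that this code enters by hand so their bytes can still be yielded after the call returns. A hedged sketch of the trade-off, using a local file as a stand-in for the SDK's streaming response (names are illustrative):

    from collections.abc import Iterator

    def iter_bytes_with(path: str) -> Iterator[bytes]:
        # Idiomatic form: __exit__ runs when the generator is closed or
        # exhausted, so the resource is released deterministically.
        with open(path, "rb") as f:
            yield from iter(lambda: f.read(1024), b"")

    def iter_bytes_dunder(path: str) -> Iterator[bytes]:
        # The pattern the diff keeps (hence the noqa): entering the context
        # manager by hand. Nothing ever calls __exit__, so the handle is
        # only reclaimed when the object is garbage collected.
        f = open(path, "rb").__enter__()
        yield from iter(lambda: f.read(1024), b"")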

@@ -1,6 +1,6 @@
 model: eu.anthropic.claude-3-haiku-20240307-v1:0
 label:
-  en_US: Claude 3 Haiku(Cross Region Inference)
+  en_US: Claude 3 Haiku(EU.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,6 +1,6 @@
 model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0
 label:
-  en_US: Claude 3.5 Sonnet(Cross Region Inference)
+  en_US: Claude 3.5 Sonnet(EU.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,6 +1,6 @@
 model: eu.anthropic.claude-3-sonnet-20240229-v1:0
 label:
-  en_US: Claude 3 Sonnet(Cross Region Inference)
+  en_US: Claude 3 Sonnet(EU.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,8 +1,8 @@
 # standard import
 import base64
-import io
 import json
 import logging
+import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
@@ -17,7 +17,6 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
-from PIL.Image import Image
 # local import
 from core.model_runtime.callbacks.base_callback import Callback
@@ -443,8 +442,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                try:
                    url = message_content.data
                    image_content = requests.get(url).content
-                    with Image.open(io.BytesIO(image_content)) as img:
-                        mime_type = f"image/{img.format.lower()}"
+                    if "?" in url:
+                        url = url.split("?")[0]
+                    mime_type, _ = mimetypes.guess_type(url)
                    base64_data = base64.b64encode(image_content).decode("utf-8")
                except Exception as ex:
                    raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
@@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                    base64_data = data_split[1]
                    image_content = base64.b64decode(base64_data)
-                    if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                        raise ValueError(
                            f"Unsupported image type {mime_type}, "
                            f"only support image/jpeg, image/png, image/gif, and image/webp"
@@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        if error_code == "AccessDeniedException":
            return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
            return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
            return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
            "ModelTimeoutException",
            "ModelErrorException",
            "InternalServerException",
            "ModelNotReadyException",
-        ]:
+        }:
            return InvokeServerUnavailableError(error_msg)
        elif error_code == "ModelStreamErrorException":
            return InvokeConnectionError(error_msg)
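The third hunk above replaces Pillow-based format sniffing (decoding the downloaded image just to read `img.format`) with an extension lookup on the URL, stripping any query string first. A minimal sketch of the new logic in isolation (the URL is illustrative):

    import mimetypes
    from typing import Optional

    def guess_image_mime_type(url: str) -> Optional[str]:
        # Signed-URL query parameters would confuse the extension lookup,
        # so cut them off first, as the diff does.
        if "?" in url:
            url = url.split("?")[0]
        # Extension-based only: no download or decode required, but a URL
        # without a recognizable extension yields None.
        mime_type, _ = mimetypes.guess_type(url)
        return mime_type

    print(guess_image_mime_type("https://example.com/cat.png?X-Amz-Signature=abc"))  # image/png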

@@ -1,6 +1,6 @@
 model: us.anthropic.claude-3-haiku-20240307-v1:0
 label:
-  en_US: Claude 3 Haiku(Cross Region Inference)
+  en_US: Claude 3 Haiku(US.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,6 +1,6 @@
 model: us.anthropic.claude-3-opus-20240229-v1:0
 label:
-  en_US: Claude 3 Opus(Cross Region Inference)
+  en_US: Claude 3 Opus(US.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,6 +1,6 @@
 model: us.anthropic.claude-3-5-sonnet-20240620-v1:0
 label:
-  en_US: Claude 3.5 Sonnet(Cross Region Inference)
+  en_US: Claude 3.5 Sonnet(US.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -1,6 +1,6 @@
 model: us.anthropic.claude-3-sonnet-20240229-v1:0
 label:
-  en_US: Claude 3 Sonnet(Cross Region Inference)
+  en_US: Claude 3 Sonnet(US.Cross Region Inference)
 model_type: llm
 features:
   - agent-thought

@@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
        if error_code == "AccessDeniedException":
            return InvokeAuthorizationError(error_msg)
-        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+        elif error_code in {"ResourceNotFoundException", "ValidationException"}:
            return InvokeBadRequestError(error_msg)
-        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+        elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}:
            return InvokeRateLimitError(error_msg)
-        elif error_code in [
+        elif error_code in {
            "ModelTimeoutException",
            "ModelErrorException",
            "InternalServerException",
            "ModelNotReadyException",
-        ]:
+        }:
            return InvokeServerUnavailableError(error_msg)
        elif error_code == "ModelStreamErrorException":
            return InvokeConnectionError(error_msg)

@@ -6,10 +6,10 @@ from collections.abc import Generator
 from typing import Optional, Union, cast
 import google.ai.generativelanguage as glm
-import google.api_core.exceptions as exceptions
 import google.generativeai as genai
-import google.generativeai.client as client
 import requests
+from google.api_core import exceptions
+from google.generativeai import client
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part
 from PIL import Image

@@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
        if "huggingfacehub_api_type" not in credentials:
            raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.")
-        if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"):
+        if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}:
            raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.")
        if "huggingfacehub_api_token" not in credentials:
@@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
                credentials["huggingfacehub_api_token"], model
            )
-            if credentials["task_type"] not in ("text2text-generation", "text-generation"):
+            if credentials["task_type"] not in {"text2text-generation", "text-generation"}:
                raise CredentialsValidateFailedError(
                    "Huggingface Hub Task Type must be one of text2text-generation, text-generation."
                )

@@ -49,8 +49,7 @@ class HuggingfaceTeiRerankModel(RerankModel):
            return RerankResult(model=model, docs=[])
        server_url = credentials["server_url"]
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)
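This and the following hunks replace the two-line `endswith`/slice idiom with `str.removesuffix`, available since Python 3.9. A small sketch of the equivalence, and of why `rstrip` would not be a drop-in substitute:

    url = "http://tei.example.com/"

    # Old idiom from the diff:
    if url.endswith("/"):
        url = url[:-1]

    # New idiom (Python 3.9+): removes at most one exact suffix and is a
    # no-op when the suffix is absent.
    assert "http://tei.example.com/".removesuffix("/") == "http://tei.example.com"
    assert "http://tei.example.com".removesuffix("/") == "http://tei.example.com"

    # rstrip("/") treats its argument as a character set and strips *all*
    # trailing slashes, which is a different operation.
    assert "http://tei.example.com//".rstrip("/") == "http://tei.example.com"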

@@ -75,7 +75,7 @@ class TeiHelper:
        if len(model_type.keys()) < 1:
            raise RuntimeError("model_type is empty")
        model_type = list(model_type.keys())[0]
-        if model_type not in ["embedding", "reranker"]:
+        if model_type not in {"embedding", "reranker"}:
            raise RuntimeError(f"invalid model_type: {model_type}")
        max_input_length = response_json.get("max_input_length", 512)

@@ -42,8 +42,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
        """
        server_url = credentials["server_url"]
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
        # get model properties
        context_size = self._get_context_size(model, credentials)
@@ -119,8 +118,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
        num_tokens = 0
        server_url = credentials["server_url"]
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
        batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
        num_tokens = sum(len(tokens) for tokens in batch_tokens)

@@ -2,3 +2,4 @@
 - hunyuan-standard
 - hunyuan-standard-256k
 - hunyuan-pro
+- hunyuan-turbo

View File

@@ -0,0 +1,38 @@
+model: hunyuan-turbo
+label:
+  zh_Hans: hunyuan-turbo
+  en_US: hunyuan-turbo
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+  - name: enable_enhance
+    label:
+      zh_Hans: 功能增强
+      en_US: Enable Enhancement
+    type: boolean
+    help:
+      zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。
+      en_US: Allow the model to perform external search to enhance the generation results.
+    required: false
+    default: true
+pricing:
+  input: '0.015'
+  output: '0.05'
+  unit: '0.001'
+  currency: RMB
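For readers unfamiliar with these provider schemas: `parameter_rules` declares the knobs the UI exposes and their bounds. A rough sketch of how such a rule can be enforced, assuming the YAML above is saved as `hunyuan-turbo.yaml` (an illustrative helper, not Dify's actual loader):

```python
import yaml  # PyYAML

def clamp(rules: list[dict], name: str, value: int) -> int:
    """Clamp a user-supplied value into the [min, max] range a rule declares."""
    for rule in rules:
        if rule.get("name") == name:
            if "min" in rule:
                value = max(rule["min"], value)
            if "max" in rule:
                value = min(rule["max"], value)
    return value

with open("hunyuan-turbo.yaml") as f:
    schema = yaml.safe_load(f)

print(clamp(schema["parameter_rules"], "max_tokens", 99999))  # -> 32000
```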


@@ -18,9 +18,9 @@ class JinaProvider(ModelProvider):
         try:
             model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
-            # Use `jina-embeddings-v2-base-en` model for validate,
+            # Use `jina-embeddings-v3` model for validate,
             # no matter what model you pass in, text completion model or chat model
-            model_instance.validate_credentials(model="jina-embeddings-v2-base-en", credentials=credentials)
+            model_instance.validate_credentials(model="jina-embeddings-v3", credentials=credentials)
         except CredentialsValidateFailedError as ex:
             raise ex
         except Exception as ex:


@@ -48,8 +48,7 @@ class JinaRerankModel(RerankModel):
             return RerankResult(model=model, docs=[])
         base_url = credentials.get("base_url", "https://api.jina.ai/v1")
-        if base_url.endswith("/"):
-            base_url = base_url[:-1]
+        base_url = base_url.removesuffix("/")
         try:
             response = httpx.post(


@@ -0,0 +1,9 @@
+model: jina-embeddings-v3
+model_type: text-embedding
+model_properties:
+  context_size: 8192
+  max_chunks: 2048
+pricing:
+  input: '0.001'
+  unit: '0.001'
+  currency: USD


@@ -44,8 +44,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
             raise CredentialsValidateFailedError("api_key is required")
         base_url = credentials.get("base_url", self.api_base)
-        if base_url.endswith("/"):
-            base_url = base_url[:-1]
+        base_url = base_url.removesuffix("/")
         url = base_url + "/embeddings"
         headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
@@ -57,6 +56,9 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]}
+        if model == "jina-embeddings-v3":
+            data["task"] = "text-matching"
         try:
             response = post(url, headers=headers, data=dumps(data))
         except Exception as e:
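The two added lines attach a `task` hint only for the v3 model. A sketch of the payload that ends up on the wire (field names as in the code above; the texts are placeholders):

```python
from json import dumps

model = "jina-embeddings-v3"
data = {"model": model, "input": ["first chunk", "second chunk"]}
if model == "jina-embeddings-v3":
    data["task"] = "text-matching"  # older models get no task field at all
print(dumps(data, indent=2))
```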


@@ -100,9 +100,9 @@ class MinimaxChatCompletion:
         return self._handle_chat_generate_response(response)
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)
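The same code-to-exception chain recurs in the Pro client and the embedding model below. Where a mapping recurs this often, a lookup table is the usual next refactor; a sketch, not what the commit does (exception names mirror the ones above, stubbed here for self-containment):

```python
class InternalServerError(Exception): pass
class RateLimitReachedError(Exception): pass
class InvalidAuthenticationError(Exception): pass

ERROR_BY_CODE = {
    1000: InternalServerError, 1001: InternalServerError,
    1013: InternalServerError, 1027: InternalServerError,
    1002: RateLimitReachedError, 1039: RateLimitReachedError,
    1004: InvalidAuthenticationError,
}

def handle_error(code: int, msg: str) -> None:
    exc_type = ERROR_BY_CODE.get(code)
    if exc_type is not None:
        raise exc_type(msg)

handle_error(9999, "unknown codes pass through silently")
```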


@@ -105,9 +105,9 @@ class MinimaxChatCompletionPro:
         return self._handle_chat_generate_response(response)
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001 or code == 1013 or code == 1027:
+        if code in {1000, 1001, 1013, 1027}:
             raise InternalServerError(msg)
-        elif code == 1002 or code == 1039:
+        elif code in {1002, 1039}:
             raise RateLimitReachedError(msg)
         elif code == 1004:
             raise InvalidAuthenticationError(msg)


@@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
             raise CredentialsValidateFailedError("Invalid api key")
     def _handle_error(self, code: int, msg: str):
-        if code == 1000 or code == 1001:
+        if code in {1000, 1001}:
             raise InternalServerError(msg)
         elif code == 1002:
             raise RateLimitReachedError(msg)


@@ -31,3 +31,4 @@ pricing:
   output: '0.002'
   unit: '0.001'
   currency: USD
+deprecated: true


@@ -31,3 +31,4 @@ pricing:
   output: '0.004'
   unit: '0.001'
   currency: USD
+deprecated: true


@@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         model_mode = self.get_model_mode(base_model, credentials)
         # transform response format
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             if model_mode == LLMMode.CHAT:
                 # chat model
@@ -615,10 +615,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         block_as_stream = False
         if model.startswith("o1"):
-            block_as_stream = True
-            stream = False
-            if "stream_options" in extra_model_kwargs:
-                del extra_model_kwargs["stream_options"]
+            if stream:
+                block_as_stream = True
+                stream = False
+
+                if "stream_options" in extra_model_kwargs:
+                    del extra_model_kwargs["stream_options"]
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
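The reworked o1 branch only downgrades the request when the caller actually asked for streaming, instead of unconditionally. A self-contained sketch of the underlying "block as stream" idea (names are illustrative, not the project's):

```python
from collections.abc import Generator
from typing import Union

def invoke(prompt: str, stream: bool) -> Union[str, Generator[str, None, None]]:
    block_as_stream = False
    if stream:
        # This model family rejects streaming: downgrade to a blocking call,
        # but remember to re-present the result as a stream afterwards.
        block_as_stream = True
        stream = False
    text = f"full completion for {prompt!r}"  # stand-in for the blocking API call
    if block_as_stream:
        # A one-chunk generator keeps callers that expected a stream working.
        return (chunk for chunk in [text])
    return text

print(invoke("hello", stream=False))       # plain string
print(list(invoke("hello", stream=True)))  # ['full completion for ...']
```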


@@ -11,9 +11,9 @@ model_properties:
 parameter_rules:
   - name: max_tokens
     use_template: max_tokens
-    default: 65563
+    default: 65536
     min: 1
-    max: 65563
+    max: 65536
   - name: response_format
     label:
       zh_Hans: 回复格式
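The old bound 65563 reads like a transposition typo; the corrected value is the power of two you would expect for a token limit:

```latex
65536 = 2^{16} \qquad\text{whereas}\qquad 65563 = 2^{16} + 27
```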


@@ -11,9 +11,9 @@ model_properties:
 parameter_rules:
   - name: max_tokens
     use_template: max_tokens
-    default: 65563
+    default: 65536
     min: 1
-    max: 65563
+    max: 65536
   - name: response_format
     label:
       zh_Hans: 回复格式


@@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                     for i in range(len(sentences))
                 ]
                 for future in futures:
-                    yield from future.result().__enter__().iter_bytes(1024)
+                    yield from future.result().__enter__().iter_bytes(1024)  # noqa:PLC2801
             else:
                 response = client.audio.speech.with_streaming_response.create(
                     model=model, voice=voice, response_format="mp3", input=content_text.strip()
                 )
-                yield from response.__enter__().iter_bytes(1024)
+                yield from response.__enter__().iter_bytes(1024)  # noqa:PLC2801
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))


@@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         credentials["endpoint_url"] = "https://openrouter.ai/api/v1"
         credentials["mode"] = self.get_model_mode(model).value
         credentials["function_calling_type"] = "tool_call"
-        return
     def _invoke(
         self,


@@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         )
         for key, value in input_properties:
-            if key not in ["system_prompt", "prompt"] and "stop" not in key:
+            if key not in {"system_prompt", "prompt"} and "stop" not in key:
                 value_type = value.get("type")
                 if not value_type:


@@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
         )
         for input_property in input_properties:
-            if input_property[0] in ("text", "texts", "inputs"):
+            if input_property[0] in {"text", "texts", "inputs"}:
                 text_input_key = input_property[0]
                 return text_input_key
@@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
     def _generate_embeddings_by_text_input_key(
         client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str]
     ) -> list[list[float]]:
-        if text_input_key in ("text", "inputs"):
+        if text_input_key in {"text", "inputs"}:
            embeddings = []
            for text in texts:
                result = client.run(replicate_model_version, input={text_input_key: text})


@@ -30,8 +30,7 @@ class SiliconflowRerankModel(RerankModel):
             return RerankResult(model=model, docs=[])
         base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1")
-        if base_url.endswith("/"):
-            base_url = base_url[:-1]
+        base_url = base_url.removesuffix("/")
         try:
             response = httpx.post(
                 base_url + "/rerank",


@@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
         if model == "farui-plus":
             model = "qwen-farui-plus"
@@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         mode = self.get_model_mode(model, credentials)
-        if model in ["qwen-turbo-chat", "qwen-plus-chat"]:
+        if model in {"qwen-turbo-chat", "qwen-plus-chat"}:
             model = model.replace("-chat", "")
         extra_model_kwargs = {}
@@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+        if response.status_code not in {200, HTTPStatus.OK}:
             raise ServiceUnavailableError(response.message)
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
@@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         full_text = ""
         tool_calls = []
         for index, response in enumerate(responses):
-            if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+            if response.status_code not in {200, HTTPStatus.OK}:
                 raise ServiceUnavailableError(
                     f"Failed to invoke model {model}, status code: {response.status_code}, "
                     f"message: {response.message}"


@@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
         """
         Code block mode wrapper for invoking large language model
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             stop = stop or []
             self._transform_chat_json_prompts(
                 model=model,


@@ -5,7 +5,6 @@ import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
-import google.api_core.exceptions as exceptions
 import google.auth.transport.requests
 import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
@@ -17,6 +16,7 @@ from anthropic.types import (
     MessageStopEvent,
     MessageStreamEvent,
 )
+from google.api_core import exceptions
 from google.cloud import aiplatform
 from google.oauth2 import service_account
 from PIL import Image
@@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
             mime_type = data_split[0].replace("data:", "")
             base64_data = data_split[1]
-            if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+            if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                 raise ValueError(
                     f"Unsupported image type {mime_type}, "
                     f"only support image/jpeg, image/png, image/gif, and image/webp"


@@ -96,7 +96,6 @@ class Signer:
         signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
         sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
         request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials)
-        return
     @staticmethod
     def hashed_canonical_request_v4(request, meta):
@@ -105,7 +104,7 @@ class Signer:
         signed_headers = {}
         for key in request.headers:
-            if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"):
+            if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"):
                 signed_headers[key.lower()] = request.headers[key]
         if "host" in signed_headers:


@@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
         """
         Code block mode wrapper for invoking large language model
         """
-        if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]:
+        if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
             response_format = model_parameters["response_format"]
             stop = stop or []
             self._transform_json_prompts(


@@ -459,8 +459,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         if "server_url" not in credentials:
             raise CredentialsValidateFailedError("server_url is required in credentials")
-        if credentials["server_url"].endswith("/"):
-            credentials["server_url"] = credentials["server_url"][:-1]
+        credentials["server_url"] = credentials["server_url"].removesuffix("/")
         api_key = credentials.get("api_key") or "abc"


@@ -50,8 +50,7 @@ class XinferenceRerankModel(RerankModel):
         server_url = credentials["server_url"]
         model_uid = credentials["model_uid"]
         api_key = credentials.get("api_key")
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
         auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
         params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True}
@@ -98,8 +97,7 @@ class XinferenceRerankModel(RerankModel):
         if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
             raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-        if credentials["server_url"].endswith("/"):
-            credentials["server_url"] = credentials["server_url"][:-1]
+        credentials["server_url"] = credentials["server_url"].removesuffix("/")
         # initialize client
         client = Client(


@@ -45,8 +45,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
             raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-        if credentials["server_url"].endswith("/"):
-            credentials["server_url"] = credentials["server_url"][:-1]
+        credentials["server_url"] = credentials["server_url"].removesuffix("/")
         # initialize client
         client = Client(
@@ -116,8 +115,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         server_url = credentials["server_url"]
         model_uid = credentials["model_uid"]
         api_key = credentials.get("api_key")
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
         auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
         try:


@@ -45,8 +45,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         server_url = credentials["server_url"]
         model_uid = credentials["model_uid"]
         api_key = credentials.get("api_key")
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
         auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
         try:
@@ -118,8 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         if extra_args.max_tokens:
             credentials["max_tokens"] = extra_args.max_tokens
-        if server_url.endswith("/"):
-            server_url = server_url[:-1]
+        server_url = server_url.removesuffix("/")
         client = Client(
             base_url=server_url,


@@ -73,8 +73,7 @@ class XinferenceText2SpeechModel(TTSModel):
         if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
             raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
-        if credentials["server_url"].endswith("/"):
-            credentials["server_url"] = credentials["server_url"][:-1]
+        credentials["server_url"] = credentials["server_url"].removesuffix("/")
         extra_param = XinferenceHelper.get_xinference_extra_parameter(
             server_url=credentials["server_url"],
@@ -189,8 +188,7 @@ class XinferenceText2SpeechModel(TTSModel):
         :param voice: model timbre
         :return: text translated to audio file
         """
-        if credentials["server_url"].endswith("/"):
-            credentials["server_url"] = credentials["server_url"][:-1]
+        credentials["server_url"] = credentials["server_url"].removesuffix("/")
         try:
             api_key = credentials.get("api_key")


@@ -103,7 +103,7 @@ class XinferenceHelper:
             model_handle_type = "embedding"
         elif response_json.get("model_type") == "audio":
             model_handle_type = "audio"
-            if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]:
+            if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}:
                 model_ability.append("text-to-audio")
             else:
                 model_ability.append("audio-to-text")


@@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         new_prompt_messages: list[PromptMessage] = []
         for prompt_message in prompt_messages:
             copy_prompt_message = prompt_message.copy()
-            if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
+            if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                 if isinstance(copy_prompt_message.content, list):
                     # check if model is 'glm-4v'
-                    if model not in ("glm-4v", "glm-4v-plus"):
+                    if model not in {"glm-4v", "glm-4v-plus"}:
                         # not support list message
                         continue
                     # get image and
@@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 ):
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
-                    if (
-                        copy_prompt_message.role == PromptMessageRole.USER
-                        or copy_prompt_message.role == PromptMessageRole.TOOL
-                    ):
+                    if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}:
                         new_prompt_messages.append(copy_prompt_message)
             elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                 new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
@@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             else:
                 new_prompt_messages.append(copy_prompt_message)
-        if model == "glm-4v" or model == "glm-4v-plus":
+        if model in {"glm-4v", "glm-4v-plus"}:
             params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
             params = {"model": model, "messages": [], **model_parameters}
@@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             # chatglm model
             for prompt_message in new_prompt_messages:
                 # merge system message to user message
-                if (
-                    prompt_message.role == PromptMessageRole.SYSTEM
-                    or prompt_message.role == PromptMessageRole.TOOL
-                    or prompt_message.role == PromptMessageRole.USER
-                ):
+                if prompt_message.role in {
+                    PromptMessageRole.SYSTEM,
+                    PromptMessageRole.TOOL,
+                    PromptMessageRole.USER,
+                }:
                     if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user":
                         params["messages"][-1]["content"] += "\n\n" + prompt_message.content
                     else:
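The restructured ZhipuAI branch keeps the same merge rule: consecutive user-visible messages collapse into one user turn separated by blank lines. A standalone sketch of that folding, with simplified stand-in types rather than the project's message classes:

```python
def fold_messages(messages: list[tuple[str, str]]) -> list[dict]:
    """Merge consecutive role-eligible messages into the previous user turn."""
    params: list[dict] = []
    for role, content in messages:
        if role in {"system", "tool", "user"}:
            if params and params[-1]["role"] == "user":
                params[-1]["content"] += "\n\n" + content
            else:
                params.append({"role": "user", "content": content})
        else:
            params.append({"role": role, "content": content})
    return params

print(fold_messages([("system", "be terse"), ("user", "hi"), ("assistant", "hello")]))
# [{'role': 'user', 'content': 'be terse\n\nhi'}, {'role': 'assistant', 'content': 'hello'}]
```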


@@ -127,8 +127,7 @@ class SSELineParser:
         field, _p, value = line.partition(":")
-        if value.startswith(" "):
-            value = value[1:]
+        value = value.removeprefix(" ")
         if field == "data":
             self._data.append(value)
         elif field == "event":
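Per the Server-Sent Events spec, exactly one leading space after the colon is stripped from a field value; `removeprefix(" ")` does precisely that (at most one space), where `lstrip` would eat all of them. A quick check:

```python
line = "data:  two leading spaces"
field, _, value = line.partition(":")
assert value.removeprefix(" ") == " two leading spaces"  # only the first space goes
assert value.lstrip(" ") == "two leading spaces"         # lstrip would remove both
```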


@@ -1,5 +1,4 @@
 from __future__ import annotations
-from .fine_tuning_job import FineTuningJob as FineTuningJob
-from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
-from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
+from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob
+from .fine_tuning_job_event import FineTuningJobEvent


@@ -75,7 +75,7 @@ class CommonValidator:
         if not isinstance(value, str):
             raise ValueError(f"Variable {credential_form_schema.variable} should be string")
-        if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]:
+        if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}:
             # If the value is in options, no validation is performed
             if credential_form_schema.options:
                 if value not in [option.value for option in credential_form_schema.options]:
@@ -83,7 +83,7 @@ class CommonValidator:
         if credential_form_schema.type == FormType.SWITCH:
             # If the value is not in ['true', 'false'], an exception is thrown
-            if value.lower() not in ["true", "false"]:
+            if value.lower() not in {"true", "false"}:
                 raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")
             value = True if value.lower() == "true" else False


@@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector):
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
             parsed_url = urlparse(config.host)
-            if parsed_url.scheme in ["http", "https"]:
+            if parsed_url.scheme in {"http", "https"}:
                 hosts = f"{config.host}:{config.port}"
             else:
                 hosts = f"http://{config.host}:{config.port}"
@@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector):
         return uuids
     def text_exists(self, id: str) -> bool:
-        return self._client.exists(index=self._collection_name, id=id).__bool__()
+        return bool(self._client.exists(index=self._collection_name, id=id))
     def delete_by_ids(self, ids: list[str]) -> None:
         for id in ids:
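On `bool(...)` versus `.__bool__()`: calling the dunder directly is what ruff's PLC2801 (unnecessary dunder call) flags; the builtin is the idiomatic spelling and also handles objects that define `__len__` instead. A tiny illustration with a stand-in class:

```python
class HeadResponse:
    """Stand-in for a client response object that defines truthiness."""
    def __init__(self, found: bool) -> None:
        self.found = found
    def __bool__(self) -> bool:
        return self.found

assert bool(HeadResponse(True)) is True         # preferred: builtin protocol call
assert HeadResponse(False).__bool__() is False  # works, but flagged as unidiomatic
```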


@@ -35,7 +35,7 @@ class MyScaleVector(BaseVector):
         super().__init__(collection_name)
         self._config = config
         self._metric = metric
-        self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
+        self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC
         self._client = get_client(
             host=config.host,
             port=config.port,
@@ -92,7 +92,7 @@ class MyScaleVector(BaseVector):
     @staticmethod
     def escape_str(value: Any) -> str:
-        return "".join(" " if c in ("\\", "'") else c for c in str(value))
+        return "".join(" " if c in {"\\", "'"} else c for c in str(value))
     def text_exists(self, id: str) -> bool:
         results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")


@@ -223,15 +223,7 @@ class OracleVector(BaseVector):
         words = pseg.cut(query)
         current_entity = ""
         for word, pos in words:
-            if (
-                pos == "nr"
-                or pos == "Ng"
-                or pos == "eng"
-                or pos == "nz"
-                or pos == "n"
-                or pos == "ORG"
-                or pos == "v"
-            ):  # nr: person name, ns: place name, nt: organization name
+            if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:  # nr: person name, ns: place name, nt: organization name
                 current_entity += word
             else:
                 if current_entity:


@@ -98,17 +98,17 @@ class ExtractProcessor:
                 unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
                 unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
                 if etl_type == "Unstructured":
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = (
                             UnstructuredMarkdownExtractor(file_path, unstructured_api_url)
                             if is_automatic
                             else MarkdownExtractor(file_path, autodetect_encoding=True)
                         )
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
@@ -134,13 +134,13 @@ class ExtractProcessor:
                             else TextExtractor(file_path, autodetect_encoding=True)
                         )
                 else:
-                    if file_extension == ".xlsx" or file_extension == ".xls":
+                    if file_extension in {".xlsx", ".xls"}:
                         extractor = ExcelExtractor(file_path)
                     elif file_extension == ".pdf":
                         extractor = PdfExtractor(file_path)
-                    elif file_extension in [".md", ".markdown"]:
+                    elif file_extension in {".md", ".markdown"}:
                         extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
-                    elif file_extension in [".htm", ".html"]:
+                    elif file_extension in {".htm", ".html"}:
                         extractor = HtmlExtractor(file_path)
                     elif file_extension == ".docx":
                         extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
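The if/elif ladder maps file extensions to extractor classes; when the branches all share a shape, a dispatch table is the usual alternative. A sketch with stand-in callables rather than the project's extractor classes:

```python
from collections.abc import Callable

def excel(path: str) -> str: return f"excel:{path}"
def pdf(path: str) -> str: return f"pdf:{path}"
def markdown(path: str) -> str: return f"markdown:{path}"

# Several extensions can share one entry, mirroring `in {".xlsx", ".xls"}`.
EXTRACTORS: dict[str, Callable[[str], str]] = {
    ".xlsx": excel, ".xls": excel,
    ".pdf": pdf,
    ".md": markdown, ".markdown": markdown,
}

def extract(path: str, ext: str) -> str:
    try:
        return EXTRACTORS[ext](path)
    except KeyError:
        raise ValueError(f"unsupported extension: {ext}") from None

print(extract("report.xls", ".xls"))  # excel:report.xls
```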


@@ -32,7 +32,7 @@ class FirecrawlApp:
             else:
                 raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
-        elif response.status_code in [402, 409, 500]:
+        elif response.status_code in {402, 409, 500}:
             error_message = response.json().get("error", "Unknown error occurred")
             raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
         else:


@@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor):
                 multi_select_list = property_value[type]
                 for multi_select in multi_select_list:
                     value.append(multi_select["name"])
-            elif type == "rich_text" or type == "title":
+            elif type in {"rich_text", "title"}:
                 if len(property_value[type]) > 0:
                     value = property_value[type][0]["plain_text"]
                 else:
                     value = ""
-            elif type == "select" or type == "status":
+            elif type in {"select", "status"}:
                 if property_value[type]:
                     value = property_value[type]["name"]
                 else:


@@ -115,7 +115,7 @@ class DatasetRetrieval:
                 available_datasets.append(dataset)
         all_documents = []
-        user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
         if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
             all_documents = self.single_retrieve(
                 app_id,
@@ -426,7 +426,7 @@ class DatasetRetrieval:
                 retrieval_method=retrieval_model["search_method"],
                 dataset_id=dataset.id,
                 query=query,
-                top_k=top_k,
+                top_k=retrieval_model.get("top_k") or 2,
                 score_threshold=retrieval_model.get("score_threshold", 0.0)
                 if retrieval_model["score_threshold_enabled"]
                 else 0.0,
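The `top_k` fix uses `get("top_k") or 2` rather than `get("top_k", 2)`; the difference matters when the key exists but holds a falsy value such as `None`:

```python
retrieval_model = {"top_k": None}
# dict.get with a default only fires when the key is *missing*:
assert retrieval_model.get("top_k", 2) is None
# `or` also covers an explicit None (and 0), falling back to 2 in both cases:
assert (retrieval_model.get("top_k") or 2) == 2
```

The `or` form also coerces an explicit `0` to the default, which is the desired behavior for a result count.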

Some files were not shown because too many files have changed in this diff.