diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index e1c0bf33a4..fd98db24b9 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -50,9 +50,18 @@ jobs: - name: Run ModelRuntime run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh + - name: Run dify config tests + run: poetry run -C api python dev/pytest/pytest_config_tests.py + - name: Run Tool run: poetry run -C api bash dev/pytest/pytest_tools.sh + - name: Run mypy + run: | + pushd api + poetry run python -m mypy --install-types --non-interactive . + popd + - name: Set up dotenvs run: | cp docker/.env.example docker/.env diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index bc65c19a91..d3146cd90d 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -9,5 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml +yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" +echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" diff --git a/api/.env.example b/api/.env.example index 6dcbae5db2..d5d4e4486f 100644 --- a/api/.env.example +++ b/api/.env.example @@ -60,17 +60,8 @@ DB_DATABASE=dify STORAGE_TYPE=opendal # Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal -STORAGE_OPENDAL_SCHEME=fs -# OpenDAL FS +OPENDAL_SCHEME=fs OPENDAL_FS_ROOT=storage -# OpenDAL S3 -OPENDAL_S3_ROOT=/ -OPENDAL_S3_BUCKET=your-bucket-name -OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com -OPENDAL_S3_ACCESS_KEY_ID=your-access-key -OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key -OPENDAL_S3_REGION=your-region -OPENDAL_S3_SERVER_SIDE_ENCRYPTION= # S3 Storage configuration S3_USE_AWS_MANAGED_IAM=false @@ -313,8 +304,7 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 # Model configuration -MULTIMODAL_SEND_IMAGE_FORMAT=base64 -MULTIMODAL_SEND_VIDEO_FORMAT=base64 +MULTIMODAL_SEND_FORMAT=base64 PROMPT_GENERATION_MAX_TOKENS=512 CODE_GENERATION_MAX_TOKENS=1024 @@ -409,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 # App configuration @@ -446,3 +437,5 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false # Maximum number of submitted thread count in a ThreadPool for parallel node execution MAX_SUBMIT_COUNT=100 +# Lockout duration in seconds +LOGIN_LOCKOUT_DURATION=86400 \ No newline at end of file diff --git a/api/.ruff.toml b/api/.ruff.toml index 0f3185223c..26a1b977a9 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -70,7 +70,6 @@ ignore = [ "SIM113", # eumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false - "SIM300", # yoda-conditions, ] [lint.per-file-ignores] diff --git a/api/app.py b/api/app.py index 996e2e890f..c6a0829080 100644 --- a/api/app.py +++ b/api/app.py @@ -1,13 +1,30 @@ -from app_factory import create_app -from libs import threadings_utils, version_utils 
+from libs import version_utils # preparation before creating app version_utils.check_supported_python_version() -threadings_utils.apply_gevent_threading_patch() + + +def is_db_command(): + import sys + + if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": + return True + return False + # create app -app = create_app() -celery = app.extensions["celery"] +if is_db_command(): + from app_factory import create_migrations_app + + app = create_migrations_app() +else: + from app_factory import create_app + from libs import threadings_utils + + threadings_utils.apply_gevent_threading_patch() + + app = create_app() + celery = app.extensions["celery"] if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 7dc08c4d93..c0714116a3 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,5 +1,4 @@ import logging -import os import time from configs import dify_config @@ -17,15 +16,6 @@ def create_flask_app_with_configs() -> DifyApp: dify_app = DifyApp(__name__) dify_app.config.from_mapping(dify_config.model_dump()) - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - return dify_app @@ -98,3 +88,14 @@ def initialize_extensions(app: DifyApp): end_time = time.perf_counter() if dify_config.DEBUG: logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)") + + +def create_migrations_app(): + app = create_flask_app_with_configs() + from extensions import ext_database, ext_migrate + + # Initialize only required extensions + ext_database.init_app(app) + ext_migrate.init_app(app) + + return app diff --git a/api/commands.py b/api/commands.py index 757ee7cb4c..86798567e8 100644 --- a/api/commands.py +++ b/api/commands.py @@ -160,8 +160,7 @@ def migrate_annotation_vector_database(): try: # get apps info apps = ( - db.session.query(App) - .filter(App.status == "normal") + App.query.filter(App.status == "normal") .order_by(App.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -286,8 +285,7 @@ def migrate_knowledge_vector_database(): while True: try: datasets = ( - db.session.query(Dataset) - .filter(Dataset.indexing_technique == "high_quality") + Dataset.query.filter(Dataset.indexing_technique == "high_quality") .order_by(Dataset.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -451,7 +449,8 @@ def convert_to_agent_apps(): if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) + if app is not None: + apps.append(app) if len(apps) == 0: break @@ -556,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str if language not in languages: language = "en-US" - name = name.strip() + # Strip whitespace; empty names become None. The UTF-8 round-trip rejects un-encodable input.
+ name = name.strip().encode("utf-8").decode("utf-8") if name else None # generate random password new_password = secrets.token_urlsafe(16) @@ -621,6 +621,10 @@ where sites.id is null limit 1000""" try: app = db.session.query(App).filter(App.id == app_id).first() + if not app: + print(f"App {app_id} not found") + continue + tenant = app.tenant if tenant: accounts = tenant.get_accounts() diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f1a1b92e98..fcecb346b0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -297,7 +297,6 @@ class HttpConfig(BaseSettings): ) @computed_field - @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") @@ -308,7 +307,6 @@ class HttpConfig(BaseSettings): ) @computed_field - @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") @@ -491,6 +489,11 @@ class WorkflowConfig(BaseSettings): default=5, ) + WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field( + description="Maximum allowed depth for nested parallel executions", + default=3, + ) + MAX_VARIABLE_SIZE: PositiveInt = Field( description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.", default=200 * 1024, @@ -543,6 +546,11 @@ class AuthConfig(BaseSettings): default=60, ) + LOGIN_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ @@ -718,14 +726,9 @@ class IndexingConfig(BaseSettings): ) -class VisionFormatConfig(BaseSettings): - MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( - description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", - default="base64", - ) - - MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field( - description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64", +class MultiModalTransferConfig(BaseSettings): + MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field( + description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64", default="base64", ) @@ -768,27 +771,27 @@ class PositionConfig(BaseSettings): default="", ) - @computed_field + @property def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} @@ -833,13 +836,13 @@ class FeatureConfig( FileAccessConfig, FileUploadConfig, HttpConfig, - VisionFormatConfig, 
InnerAPIConfig, IndexingConfig, LoggingConfig, MailConfig, ModelLoadBalanceConfig, ModerationConfig, + MultiModalTransferConfig, PositionConfig, RagEtlConfig, SecurityConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 9265a48d9b..f6a44eaa47 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings): ) @computed_field - @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS @@ -168,7 +167,6 @@ class DatabaseConfig(BaseSettings): ) @computed_field - @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { "pool_size": self.SQLALCHEMY_POOL_SIZE, @@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig): ) @computed_field - @property def CELERY_RESULT_BACKEND(self) -> str | None: return ( "db+{}".format(self.SQLALCHEMY_DATABASE_URI) @@ -214,7 +211,6 @@ class CeleryConfig(DatabaseConfig): else self.CELERY_BROKER_URL ) - @computed_field @property def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False diff --git a/api/configs/middleware/storage/opendal_storage_config.py b/api/configs/middleware/storage/opendal_storage_config.py index 56a8d24edf..ef38070e53 100644 --- a/api/configs/middleware/storage/opendal_storage_config.py +++ b/api/configs/middleware/storage/opendal_storage_config.py @@ -1,51 +1,9 @@ -from enum import StrEnum -from typing import Literal - from pydantic import Field from pydantic_settings import BaseSettings -class OpenDALScheme(StrEnum): - FS = "fs" - S3 = "s3" - - class OpenDALStorageConfig(BaseSettings): - STORAGE_OPENDAL_SCHEME: str = Field( - default=OpenDALScheme.FS.value, + OPENDAL_SCHEME: str = Field( + default="fs", description="OpenDAL scheme.", ) - # FS - OPENDAL_FS_ROOT: str = Field( - default="storage", - description="Root path for local storage.", - ) - # S3 - OPENDAL_S3_ROOT: str = Field( - default="/", - description="Root path for S3 storage.", - ) - OPENDAL_S3_BUCKET: str = Field( - default="", - description="S3 bucket name.", - ) - OPENDAL_S3_ENDPOINT: str = Field( - default="https://s3.amazonaws.com", - description="S3 endpoint URL.", - ) - OPENDAL_S3_ACCESS_KEY_ID: str = Field( - default="", - description="S3 access key ID.", - ) - OPENDAL_S3_SECRET_ACCESS_KEY: str = Field( - default="", - description="S3 secret access key.", - ) - OPENDAL_S3_REGION: str = Field( - default="", - description="S3 region.", - ) - OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field( - default="", - description="S3 server-side encryption.", - ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 0c2ccd826e..4a168a3fb1 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.13.2", + default="0.14.2", ) COMMIT_SHA: str = Field( diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py index d1f6781ed3..03c64ea00f 100644 --- a/api/configs/remote_settings_sources/apollo/client.py +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -4,6 +4,7 @@ import logging import os import threading import time +from collections.abc import Mapping from pathlib import Path from .python_3x import http_request, 
makedirs_wrapper @@ -255,8 +256,8 @@ class ApolloClient: logger.info("stopped, long_poll") # add the need for endorsement to the header - def _sign_headers(self, url): - headers = {} + def _sign_headers(self, url: str) -> Mapping[str, str]: + headers: dict[str, str] = {} if self.secret == "": return headers uri = url[len(self.config_url) : len(url)] diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 7e1a196356..c26d8c0186 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,8 +1,9 @@ import json +from collections.abc import Mapping from models.model import AppMode -default_app_templates = { +default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index c71f1ce5a3..9f762b3135 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException class FilenameNotExistsError(HTTPException): code = 400 description = "The specified filename does not exist." + + +class RemoteFileUploadError(HTTPException): + code = 400 + description = "Error uploading remote file." diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 79869916ed..b1ebc444a5 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore parameters__system_parameters = { "image_file_size_limit": fields.Integer, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 9f52b98ae7..8b5378c132 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -3,6 +3,25 @@ from flask import Blueprint from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi +from .explore.audio import ChatAudioApi, ChatTextApi +from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from .explore.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from .explore.message import ( + MessageFeedbackApi, + MessageListApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from .explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi @@ -67,15 +86,81 @@ from .datasets import ( # Import explore controllers from .explore import ( - audio, - completion, - conversation, installed_app, - message, parameter, recommended_app, saved_message, - workflow, +) + +# Explore Audio +api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") + +# Explore Completion +api.add_resource( + CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
+ endpoint="installed_app_stop_chat_completion", +) + +# Explore Conversation +api.add_resource( + ConversationRenameApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin", + endpoint="installed_app_conversation_unpin", +) + + +# Explore Message +api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions", + endpoint="installed_app_suggested_question", +) +# Explore Workflow +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop" ) # Import tag controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index bbb4284b65..783a98caaa 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -33,7 +33,7 @@ def admin_required(view): if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format.
Expected 'Bearer <api-key>' format.") - if dify_config.ADMIN_API_KEY != auth_token: + if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") return view(*args, **kwargs) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b612f7bd96..eb42507c63 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,7 @@ -import flask_restful -from flask_login import current_user +from typing import Any + +import flask_restful # type: ignore +from flask_login import current_user # type: ignore from flask_restful import Resource, fields, marshal_with from sqlalchemy import select from sqlalchemy.orm import Session @@ -46,14 +48,15 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None - token_prefix = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None + token_prefix: str | None = None max_keys = 10 @marshal_with(api_key_list) def get(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) keys = ( @@ -65,6 +68,7 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_fields) def post(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_editor: @@ -97,11 +101,12 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None def delete(self, resource_id, api_key_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa5..8d0c5b84af 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d433415894..920cae0d85 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index fd05cbc19b..24f1020c18 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py
@@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -110,7 +110,7 @@ class AnnotationListApi(Resource): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default=None, type=str) + keyword = request.args.get("keyword", default="", type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index da72b704c7..9cd56cef0b 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,8 +1,8 @@ import uuid from typing import cast -from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index fba7a8a935..47acb47a2c 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 695b8890e3..9d26af276d 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 9896fcaab8..dba41e5c47 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,7 @@ import logging -import flask_login -from flask_restful import Resource, reqparse +import flask_login # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index a25004be4d..8827f129d9 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,9 +1,9 @@ from datetime import UTC, datetime -import pytz -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +import pytz # pip install pytz +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from 
flask_restful.inputs import int_range # type: ignore from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -77,8 +77,9 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) + # FIXME, the type ignore in this file if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -222,7 +223,7 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at <= end_datetime_utc) if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -234,7 +235,7 @@ class ChatConversationApi(Resource): if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( - query.options(joinedload(Conversation.messages)) + query.options(joinedload(Conversation.messages)) # type: ignore .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) .having(func.count(Message.id) >= args["message_count_gte"]) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1..c0a20b7160 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 9c3cbe4e3e..8518d34a8e 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ import os -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.error import ( diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a15..b5828b6b4b 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,8 @@ import logging -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 8ba195f5a5..8ecc8a9db5 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,8 +1,9 @@ import json +from typing import cast 
from flask import request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -26,7 +27,9 @@ class ModelConfigResource(Resource): """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, + config=cast(dict, request.json), + app_mode=AppMode.value_of(app_model.mode), ) new_app_model_config = AppModelConfig( @@ -38,9 +41,11 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = ( + original_app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() ) + if original_app_model_config is None: + raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} @@ -65,7 +70,7 @@ class ModelConfigResource(Resource): provider_type=agent_tool_entity.provider_type, identity_id=f"AGENT.{app_model.id}", ) - except Exception as e: + except Exception: continue # get decrypted parameters @@ -97,7 +102,7 @@ class ModelConfigResource(Resource): app_id=app_model.id, agent_tool=agent_tool_entity, ) - except Exception as e: + except Exception: continue manager = ToolParameterConfigurationManager( diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 47b58396a1..dd25af8ebf 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,5 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore +from werkzeug.exceptions import BadRequest from controllers.console import api from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist @@ -26,7 +27,7 @@ class TraceAppConfigApi(Resource): return {"has_not_configured": True} return trace_config except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -48,7 +49,7 @@ class TraceAppConfigApi(Resource): raise TracingConfigCheckError() return result except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -68,7 +69,7 @@ class TraceAppConfigApi(Resource): raise TracingConfigNotExist() return {"result": "success"} except Exception as e: - raise e + raise BadRequest(str(e)) @setup_required @login_required @@ -85,7 +86,7 @@ class TraceAppConfigApi(Resource): raise TracingConfigNotExist() return {"result": "success"} except Exception as e: - raise e + raise BadRequest(str(e)) api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 407f689819..db29b95c41 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,7 +1,7 @@ from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource,
marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language @@ -50,7 +50,7 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() + site = Site.query.filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ "title", diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index db5e282409..3b21108cea 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -3,8 +3,8 @@ from decimal import Decimal import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 94f7f2b7a9..291a6e5dd6 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -2,10 +2,11 @@ import json import logging from flask import abort, request -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from configs import dify_config from controllers.console import api from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model @@ -460,7 +461,21 @@ class ConvertToWorkflowApi(Resource): } +class WorkflowConfigApi(Resource): + """Resource for workflow configuration.""" + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + return { + "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + } + + api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") +api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 2940556f84..882c53e4fb 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9..25a99c1e15 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with,
reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707b..097bf7d188 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,8 +3,8 @@ from decimal import Decimal import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index f84c592bba..9ad8c15847 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,8 +5,7 @@ from typing import Optional, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models import App -from models.model import AppMode +from models import App, AppMode def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index d2aa7c903b..c56f551d49 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,14 +1,14 @@ import datetime from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.helper import StrLen, email, extract_remote_ip, timezone -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -27,7 +27,7 @@ class ActivateCheckApi(Resource): invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: data = invitation.get("data", {}) - tenant: Tenant = invitation.get("tenant", None) + tenant = invitation.get("tenant", None) workspace_name = tenant.name if tenant else None workspace_id = tenant.id if tenant else None invitee_email = data.get("email") if data else None diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 465c44e9b6..ea00c2b8c2 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index faca67bb17..e911c9a5e5 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,8 +2,8 @@ import logging import requests from flask import 
current_app, redirect, request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,8 +17,8 @@ from ..wraps import account_initialization_required, setup_required def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( - client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, + client_id=dify_config.NOTION_CLIENT_ID or "", + client_secret=dify_config.NOTION_CLIENT_SECRET or "", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", ) diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index a0b9faa781..b9ce5d644d 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -126,8 +126,8 @@ class ForgotPasswordResetApi(Resource): else: try: account = AccountService.create_account_and_tenant( - email=reset_data.get("email"), - name=reset_data.get("email"), + email=reset_data.get("email", ""), + name=reset_data.get("email", ""), password=password_confirm, interface_language=languages[0], ) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f4463ce9cb..78a80fc8d7 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,8 +1,8 @@ from typing import cast -import flask_login +import flask_login # type: ignore from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore import services from constants.languages import languages diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 123046cf62..8e54da4ef6 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -78,8 +78,9 @@ class OAuthCallback(Resource): try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) - except requests.exceptions.HTTPError as e: - logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + except requests.exceptions.RequestException as e: + error_text = e.response.text if e.response else str(e) + logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): @@ -131,7 +132,7 @@ class OAuthCallback(Resource): def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account = Account.get_by_openid(provider, user_info.id) + account: Optional[Account] = Account.get_by_openid(provider, user_info.id) if not account: with Session(db.engine) as session: diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c..fd7b7bd8cb 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, 
only_edition_cloud, setup_required diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 95d4013e3a..f3c3736b25 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restful # type: ignore from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 94ffc73252..c236e1a431 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,10 +1,11 @@ import logging from argparse import ArgumentTypeError from datetime import UTC, datetime +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import asc, desc from werkzeug.exceptions import Forbidden, NotFound @@ -161,7 +162,7 @@ class DatasetDocumentListApi(Resource): f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 " f"(case insensitive)." ) - except (ArgumentTypeError, ValueError, Exception) as e: + except (ArgumentTypeError, ValueError, Exception): fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: @@ -749,8 +750,7 @@ class DocumentMetadataApi(DocumentResource): if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be a dictionary.") - - metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) document.doc_metadata = {} if doc_type == "others": @@ -964,7 +964,7 @@ class DocumentRetryApi(DocumentResource): if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) - except Exception as e: + except Exception: logging.exception(f"Failed to retry document, document id: {document_id}") continue # retry document diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 6f7ef86d2c..2d5933ca23 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,8 +3,8 @@ from datetime import UTC, datetime import pandas as pd from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index bc6e3687c1..48f360dcd1 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from 
flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 495f511275..18b746f547 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c076863..bd944602c1 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 9127c8af45..da995537e7 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 9690677f61..c7f9fec326 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -4,7 +4,6 @@ from flask import request from werkzeug.exceptions import InternalServerError import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -67,7 +66,7 @@ class ChatAudioApi(InstalledAppResource): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore app_model = installed_app.app try: @@ -118,9 +117,3 @@ class ChatTextApi(InstalledAppResource): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") -# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id', -# endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 85c43f8101..3331ded70f 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,12 +1,11 @@ import logging from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import reqparse +from flask_login import current_user # type: ignore +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import
InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -147,21 +146,3 @@ class ChatStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 - - -api.add_resource( - CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop", - endpoint="installed_app_stop_chat_completion", -) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6f9d7769b9..91916cbc1e 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,12 +1,13 @@ -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -34,14 +35,16 @@ class ConversationListApi(InstalledAppResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=current_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.EXPLORE, + pinned=pinned, + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -114,28 +117,3 @@ class ConversationUnPinApi(InstalledAppResource): WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} - - -api.add_resource( - ConversationRenameApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin", - endpoint="installed_app_conversation_unpin", -) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index
3de179164d..86550b2bdf 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,8 +1,9 @@ from datetime import UTC, datetime +from typing import Any from flask import request -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -34,7 +35,7 @@ class InstalledAppsListApi(Resource): installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ + installed_app_list: list[dict[str, Any]] = [ { "id": installed_app.id, "app": installed_app.app, @@ -47,7 +48,7 @@ class InstalledAppsListApi(Resource): for installed_app in installed_apps if installed_app.app is not None ] - installed_apps.sort( + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], app["last_used_at"] is None, @@ -55,7 +56,7 @@ class InstalledAppsListApi(Resource): ) ) - return {"installed_apps": installed_apps} + return {"installed_apps": installed_app_list} @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3d221ff30a..c3488de299 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -70,7 +69,7 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -153,21 +152,3 @@ class MessageSuggestedQuestionApi(InstalledAppResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions", - endpoint="installed_app_suggested_question", -) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fee52248a6..5bc74d16e7 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import
marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 5daaa1e7c3..be6b1f5d21 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,9 +1,10 @@ -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required +from libs.helper import AppIconUrlField from libs.login import login_required from services.recommended_app_service import RecommendedAppService @@ -12,6 +13,8 @@ app_fields = { "name": fields.String, "mode": fields.String, "icon": fields.String, + "icon_type": fields.String, + "icon_url": AppIconUrlField, "icon_background": fields.String, } diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 0fc9637479..9f0c496645 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 45f99b1db9..76d30299cd 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,9 +1,8 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError -from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -73,9 +72,3 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"} - - -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" -) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0..b7ba81fba2 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4ac0aa497e..ed6cedb220 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ -from flask_login import 
current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865..da1171412f 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 946d3db37f..8cf754bbd6 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,6 +1,9 @@ +from typing import Literal + from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with # type: ignore +from werkzeug.exceptions import Forbidden import services from configs import dify_config @@ -47,7 +50,8 @@ class FileApi(Resource): @cloud_edition_billing_resource_check("documents") def post(self): file = request.files["file"] - source = request.form.get("source") + source_str = request.form.get("source") + source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None if "file" not in request.files: raise NoFileUploadedError() @@ -58,6 +62,9 @@ class FileApi(Resource): if not file.filename: raise FilenameNotExistsError + if source == "datasets" and not current_user.is_dataset_editor: + raise Forbidden() + if source not in ("datasets", None): source = None diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946e..2a116112a3 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fac1341b39..30afc930a8 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,11 +2,12 @@ import urllib.parse from typing import cast import httpx -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields @@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from 
{url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e1f19a87a3..3b47f8f12f 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ccd3293a62..da83f64019 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -23,7 +23,7 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get("type", type=str) + tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 7dea8e554e..7773c99944 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import json import logging import requests -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index f704783cff..96ed4b7a57 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,8 +2,8 @@ import datetime import pytz from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from configs import dify_config from constants.languages import supported_language diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 114905cf1d..6e1d87cb12 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -37,7 +37,7 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( @@ -86,7 +86,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: 
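Both remote-file endpoints in this patch share the probe shown above: try HEAD, fall back to a short redirect-following GET, and convert any failure into `RemoteFileUploadError`. A hedged sketch using plain `httpx` in place of Dify's internal `ssrf_proxy` wrapper:

```python
import httpx


class RemoteFileUploadError(Exception):
    """Stand-in for controllers.common.errors.RemoteFileUploadError."""


def probe_remote_file(url: str) -> httpx.Response:
    try:
        resp = httpx.head(url)
        if resp.status_code != httpx.codes.OK:
            # Some servers reject HEAD; retry once with a bounded GET.
            resp = httpx.get(url, timeout=3, follow_redirects=True)
        if resp.status_code != httpx.codes.OK:
            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
        return resp
    except httpx.RequestError as e:
        raise RemoteFileUploadError(f"Failed to fetch file from {url}: {e}") from e
```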
model_load_balancing_service.validate_load_balancing_credentials( diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 38ed2316a5..1afb41ea87 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,7 @@ from urllib import parse -from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore import services from configs import dify_config @@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource): @account_initialization_required def delete(self, member_id): member = db.session.query(Account).filter(Account.id == str(member_id)).first() - if not member: + if member is None: abort(404) - - try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) - except services.errors.account.CannotOperateSelfError as e: - return {"code": "cannot-operate-self", "message": str(e)}, 400 - except services.errors.account.NoPermissionError as e: - return {"code": "forbidden", "message": str(e)}, 403 - except services.errors.account.MemberNotInTenantError as e: - return {"code": "member-not-found", "message": str(e)}, 404 - except Exception as e: - raise ValueError(str(e)) + else: + try: + TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + except services.errors.account.CannotOperateSelfError as e: + return {"code": "cannot-operate-self", "message": str(e)}, 400 + except services.errors.account.NoPermissionError as e: + return {"code": "forbidden", "message": str(e)}, 403 + except services.errors.account.MemberNotInTenantError as e: + return {"code": "member-not-found", "message": str(e)}, 404 + except Exception as e: + raise ValueError(str(e)) return {"result": "success"}, 204 @@ -122,10 +122,11 @@ class MemberUpdateRoleApi(Resource): return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) - if not member: + if member is None: abort(404) try: + assert member is not None, "Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index b9612c0f9d..d7d1cc8d00 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -66,7 +66,7 @@ class ModelProviderValidateApi(Resource): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.provider_credentials_validate( @@ -133,10 +133,8 @@ class ModelProviderIconApi(Resource): icon_type=icon_type, lang=lang, ) - - if not icon: - return {"message": "Icon not found"}, 404 - + if icon is None: + raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}") return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git 
a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index a03e875671..8b72a1ea3d 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -308,7 +308,7 @@ class ModelProviderModelValidateApi(Resource): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.model_credentials_validate( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d18e0ba949..4b21afb8b6 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,14 +1,16 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -112,12 +114,16 @@ class ToolBuiltinProviderUpdateApi(Resource): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -125,15 +131,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - user_id = user.id - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) @@ -329,7 +331,6 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): user = current_user - user_id = user.id tenant_id = user.current_tenant_id return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 76d76f6b58..0f99bf62e3 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,8 @@ import logging from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import 
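The tool-provider hunk above introduces a commit-at-the-edge convention: the service mutates through a caller-supplied session and the controller issues the single commit. A sketch under the same assumption, with the engine and service body as stand-ins:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in for db.engine


def update_builtin_tool_provider(*, session: Session, user_id: str, tenant_id: str,
                                 provider_name: str, credentials: dict) -> dict:
    # Mutate ORM rows through `session`; never commit inside the service.
    return {"result": "success"}


with Session(engine) as session:
    result = update_builtin_tool_provider(
        session=session,
        user_id="user-1",
        tenant_id="tenant-1",
        provider_name="example-provider",
        credentials={"api_key": "..."},
    )
    session.commit()  # one commit point, owned by the controller
```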
Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Unauthorized import services @@ -82,11 +82,7 @@ class WorkspaceListApi(Resource): parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = ( - db.session.query(Tenant) - .order_by(Tenant.created_at.desc()) - .paginate(page=args["page"], per_page=args["limit"]) - ) + tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"]) has_more = False if len(tenants.items) == args["limit"]: @@ -151,6 +147,8 @@ class SwitchWorkspaceApi(Resource): raise AccountNotLinkTenantError("Account not link tenant") new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + if new_tenant is None: + raise ValueError("Tenant not found") return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -166,7 +164,7 @@ class CustomConfigWorkspaceApi(Resource): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 0fbf450046..a6f64700f2 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -3,7 +3,7 @@ import os from functools import wraps from flask import abort, request -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError @@ -122,8 +122,8 @@ def cloud_utm_record(view): utm_info = request.cookies.get("utm_info") if utm_info: - utm_info = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info) + utm_info_dict: dict = json.loads(utm_info) + OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) except Exception as e: pass return view(*args, **kwargs) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 6b3ac93cdf..2357288a50 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,5 +1,5 @@ from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound import services diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index a298701a2f..cfcce81247 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,5 +1,5 @@ from flask import Response -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 64cb5e54ff..58d48fe361 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from 
controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index b543b5e164..c3d5386e3d 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -45,16 +45,14 @@ def enterprise_inner_api_user_auth(view): if " " in user_id: user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key: - raise ValueError("inner api key not found") + inner_api_key = request.headers.get("X-Inner-Api-Key", "") data_to_sign = f"DIFY {user_id}" signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) - signature = b64encode(signature.digest()).decode("utf-8") + signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature != token: + if signature_base64 != token: return view(*args, **kwargs) kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecff7d07e9..8388e2045d 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 5db4163647..e6bcc0bfd2 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -83,7 +83,7 @@ class TextApi(Resource): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 8d8e356c4c..1be54b386b 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c62fd77d36..334f2c5620 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound import services @@ -7,6 +8,7 @@ from controllers.service_api import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from 
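The inner-API hunk above signs `DIFY {user_id}` with HMAC-SHA1 and compares the base64 digest against the bearer token. A standalone reproduction of the signature math (key and user id are invented; `hmac.compare_digest` shown here is the constant-time comparison, whereas the handler uses plain `!=`):

```python
import hmac
from base64 import b64encode
from hashlib import sha1

inner_api_key = "example-inner-api-key"  # would come from the X-Inner-Api-Key header
user_id = "user-1234"                    # would come from the request

data_to_sign = f"DIFY {user_id}"
signature = hmac.new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
signature_base64 = b64encode(signature.digest()).decode("utf-8")


def verify(token: str) -> bool:
    return hmac.compare_digest(signature_base64, token)
```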
core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import ( conversation_delete_fields, conversation_infinite_scroll_pagination_fields, @@ -39,14 +41,16 @@ class ConversationApi(Resource): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef..27b21b9f50 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ada40ec9cb..522c7509b9 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 96d1337632..c7dd4de345 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError from controllers.service_api import api diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 799fccc228..d6a3beb6b8 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import 
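`MessageService.create_feedback` now carries an optional free-text `content` next to the rating, and the service-API parser gains a matching json argument. A stub showing only the widened call shape; the persistence logic is not part of this patch's hunks:

```python
from typing import Optional


def create_feedback(app_model, message_id: str, user,
                    rating: Optional[str], content: Optional[str]) -> dict:
    # The real service persists a feedback row; this stub only checks the shape.
    assert rating in ("like", "dislike", None)
    return {"message_id": message_id, "rating": rating, "content": content}


print(create_feedback(None, "msg-1", None, "like", "Accurate and concise."))
```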
NotFound import services.dataset_service diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5c3fc7b241..34afe2837f 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,7 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from sqlalchemy import desc from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e68f6b4dc4..34904574a8 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.service_api import api diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index d24c4597e2..75d9141a6d 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from configs import dify_config from controllers.service_api import api diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2128c4c53f..740b92ef8e 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -5,8 +5,8 @@ from functools import wraps from typing import Optional from flask import current_app, request -from flask_login import user_logged_in -from flask_restful import Resource +from flask_login import user_logged_in # type: ignore +from flask_restful import Resource # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden, Unauthorized @@ -49,6 +49,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant is None: + raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") @@ -154,8 +156,8 @@ def validate_dataset_token(view=None): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cc8255ccf4..20e071c834 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index e8521307ad..97d980d07c 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ class 
AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore try: parser = reqparse.RequestParser() @@ -82,7 +82,7 @@ class TextApi(WebApiResource): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 45b890dfc4..761771a81a 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c3b0cd4f44..28feb1ca47 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,11 +1,13 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, + sort_by=args["sort_by"], + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed2238..ce841a8814 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index a282fc63a8..1d4474015a 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError @@ -33,7 +33,7 @@ class FileApi(WebApiResource): content=file.read(), mimetype=file.mimetype, user=end_user, - source=source, + source="datasets" if source == "datasets" else None, 
) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 98891f5d00..0f47e64370 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a01ffd8612..4625c1f43d 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,7 +1,7 @@ import uuid from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index d6b8eb2855..d559ab8e07 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,10 +1,11 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse # type: ignore import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b0492e6b6f..6a9b818907 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea3..e68dc7aa4a 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with # 
type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 55b0c3e2ab..48d25e720c 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18..1b4d263bee 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebSSOAuthRequiredError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 6d03b09c87..2e270e304d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,6 @@ import json import logging import uuid -from collections.abc import Mapping, Sequence from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -49,6 +48,7 @@ logger = logging.getLogger(__name__) class BaseAgentRunner(AppRunner): def __init__( self, + *, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, conversation: Conversation, @@ -109,7 +109,7 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query = None + self.query: Optional[str] = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( @@ -158,7 +158,7 @@ class BaseAgentRunner(AppRunner): continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -203,7 +203,7 @@ class BaseAgentRunner(AppRunner): return prompt_tool - def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: """ Init tools """ @@ -251,7 +251,7 @@ class BaseAgentRunner(AppRunner): continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -401,7 +401,7 @@ class BaseAgentRunner(AppRunner): """ Organize agent history """ - result = [] + result: list[PromptMessage] = [] # check if there is a system message in the beginning of the conversation for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 8b510258e8..810c5f3893 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,8 @@ import json from abc import ABC, abstractmethod from 
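The agent-runner hunks above guard the select-option enum against `options` being `None` after the typing pass. A simplified, self-contained version of that property construction, with dataclasses standing in for Dify's `ToolParameter` entities:

```python
from dataclasses import dataclass


@dataclass
class Option:
    value: str


@dataclass
class ToolParameter:
    name: str
    type: str
    options: list[Option] | None = None


def to_json_schema_property(parameter: ToolParameter) -> dict:
    enum = []
    if parameter.type == "select":
        # `options` may be None, hence the guard added in the hunks above.
        enum = [option.value for option in parameter.options] if parameter.options else []
    prop: dict = {"type": "string"}
    if enum:
        prop["enum"] = enum
    return prop


print(to_json_schema_property(ToolParameter("lang", "select", [Option("en"), Option("zh")])))
# -> {'type': 'string', 'enum': ['en', 'zh']}
```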
collections.abc import Generator, Mapping, Sequence -from typing import Optional, Union +from typing import Optional +from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -37,8 +38,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): self, message: Message, query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + inputs: Mapping[str, str], + ) -> Generator: """ Run Cot agent application """ @@ -63,7 +64,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -94,7 +95,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids @@ -109,7 +110,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, tools=[], @@ -161,7 +162,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): # get llm usage if "usage" in usage_dict: - increase_usage(llm_usage, usage_dict["usage"]) + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() @@ -170,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation="", answer=scratchpad.agent_response or "", messages_ids=[], @@ -213,7 +215,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): agent_thought=agent_thought, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation={scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, @@ -251,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): answer=final_answer, messages_ids=[], ) - # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -325,14 +326,14 @@ class CotAgentRunner(BaseAgentRunner, ABC): """ return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: """ fill in inputs from external data tools """ for key, value in inputs.items(): try: instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) - except Exception as e: + except 
Exception: continue return instruction @@ -403,6 +404,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): if current_scratchpad: assert isinstance(message.content, str) current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") elif isinstance(message, UserPromptMessage): if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 9e12092b3d..7d407a4976 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -23,6 +23,8 @@ class CotChatAgentRunner(CotAgentRunner): assert self.app_config.agent.prompt prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -78,6 +80,7 @@ class CotChatAgentRunner(CotAgentRunner): assistant_messages = [] else: assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str for unit in agent_scratchpad: if unit.is_final(): assert isinstance(assistant_message.content, str) diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 0563090537..3a4d31e047 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -2,7 +2,12 @@ import json from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder @@ -11,7 +16,11 @@ class CotCompletionAgentRunner(CotAgentRunner): """ Organize instruction prompt """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("prompt entity is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -33,7 +42,13 @@ class CotCompletionAgentRunner(CotAgentRunner): if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" elif isinstance(message, AssistantPromptMessage): - historic_prompt += message.content + "\n\n" + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt @@ -50,7 +65,7 @@ class CotCompletionAgentRunner(CotAgentRunner): # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" - for unit in agent_scratchpad: + for unit in agent_scratchpad or []: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index a63d92c1ae..f45fa5c66e 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -40,6 +40,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): app_generate_entity = self.application_generate_entity app_config = 
self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -77,7 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -118,13 +120,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError as e: + except json.JSONDecodeError: # ensure ascii to avoid encoding error tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) @@ -133,7 +135,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): for content in chunk.delta.message.content: response += content.data else: - response += chunk.delta.message.content + response += str(chunk.delta.message.content) if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) @@ -145,13 +147,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result)) + tool_calls.extend(self.extract_blocking_tool_calls(result) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError as e: + except json.JSONDecodeError: # ensure ascii to avoid encoding error tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) @@ -164,7 +166,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): for content in result.message.content: response += content.data else: - response += result.message.content + response += str(result.message.content) if not result.message.content: result.message.content = "" @@ -265,7 +267,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response["tool_response"], + content=str(tool_response["tool_response"]), tool_call_id=tool_call_id, name=tool_call_name, ) @@ -275,9 +277,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name=None, - tool_input=None, - thought=None, + tool_name="", + tool_input="", + thought="", tool_invoke_meta={ tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, @@ -285,7 +287,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, - answer=None, + answer="", messages_ids=message_file_ids, ) self.queue_manager.publish( @@ -386,9 +388,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): if prompt_messages and not isinstance(prompt_messages[0], 
SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages + return prompt_messages or [] - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ @@ -446,7 +448,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query, []) + query_prompt_messages = self._organize_user_query(self.query or "", []) self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 085bac8601..61fa774ea5 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -38,7 +38,7 @@ class CotAgentOutputParser: except: return json_str or "" - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return @@ -67,15 +67,15 @@ class CotAgentOutputParser: for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage - response = response.delta.message.content - if not isinstance(response, str): + response_content = response.delta.message.content + if not isinstance(response_content, str): continue # stream index = 0 - while index < len(response): + while index < len(response_content): steps = 1 - delta = response[index : index + steps] + delta = response_content[index : index + steps] yield_delta = False if delta == "`": diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index b9aae7904f..646c4badb9 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -66,6 +66,8 @@ class DatasetConfigManager: dataset_configs = config.get("dataset_configs") else: dataset_configs = {"retrieval_model": "multiple"} + if dataset_configs is None: + return None query_variable = config.get("dataset_query_variable") if dataset_configs["retrieval_model"] == "single": diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index d3077aefd5..b19865ff4c 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -104,7 +104,7 @@ class ModelConfigManager: config["model"]["completion_params"] ) - return config, ["model"] + return dict(config), ["model"] @classmethod def validate_model_completion_params(cls, cp: dict) -> dict: diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index b4dacbc409..92b4185abf 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ 
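`extra_json_from_code_block` above scans fenced blocks out of the model reply with a DOTALL regex and tries to parse each as an action. A standalone sketch of that extraction; the language-tag stripping is an assumption, since the hunk does not show the parser's full block handling:

```python
import json
import re

FENCE = "`" * 3  # written this way to avoid a literal fence inside this example
BLOCK_RE = re.compile(FENCE + r"(.*?)" + FENCE, re.DOTALL)


def extract_json_from_code_blocks(text: str):
    for block in BLOCK_RE.findall(text):
        # Assumed detail: drop a leading language hint such as "json".
        block = re.sub(r"^[a-zA-Z]+\n", "", block.strip())
        try:
            yield json.loads(block)
        except json.JSONDecodeError:
            yield block  # fall back to raw text, mirroring the parser's lenient path


reply = f'Thought: call a tool\n{FENCE}json\n{{"action": "search", "action_input": "dify"}}\n{FENCE}'
print(list(extract_json_from_code_blocks(reply)))
# -> [{'action': 'search', 'action_input': 'dify'}]
```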
b/api/core/app/app_config/features/opening_statement/manager.py @@ -7,10 +7,10 @@ class OpeningStatementConfigManager: :param config: model config args """ # opening statement - opening_statement = config.get("opening_statement") + opening_statement = config.get("opening_statement", "") # suggested questions - suggested_questions_list = config.get("suggested_questions") + suggested_questions_list = config.get("suggested_questions", []) return opening_statement, suggested_questions_list diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index c9adbeb964..684f2bc8a3 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from factories import file_factory from models.account import Account from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -145,7 +146,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -323,6 +324,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AdvancedChatAppRunner( diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 18b115dfe4..a506447671 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -4,14 +4,19 @@ import logging import queue import re import threading +from collections.abc import Iterable +from typing import Optional from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentMessageEvent, QueueLLMChunkEvent, QueueNodeSucceededEvent, QueueTextChunkEvent, + WorkflowQueueMessage, ) -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.model_runtime.entities.model_entities import ModelType @@ -21,7 +26,7 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -29,13 +34,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): ) -def _process_future(future_queue, audio_queue): +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: queue.Queue[AudioTrunk], +): while True: try: future = future_queue.get() if future is None: break - for audio in future.result(): + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: audio_base64 = base64.b64encode(bytes(audio)) audio_queue.put(AudioTrunk("responding", audio=audio_base64)) except 
@@ -49,8 +60,8 @@ class AppGeneratorTTSPublisher: self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" - self._audio_queue = queue.Queue() - self._msg_queue = queue.Queue() + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( @@ -62,18 +73,16 @@ class AppGeneratorTTSPublisher: if not voice or voice not in values: self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 - self._last_audio_event = None - self._runtime_thread = threading.Thread(target=self._runtime).start() + self._last_audio_event: Optional[AudioTrunk] = None + # FIXME: find a better way to start and track this thread + threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) - def publish(self, message): - try: - self._msg_queue.put(message) - except Exception as e: - self.logger.warning(e) + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) def _runtime(self): - future_queue = queue.Queue() + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() while True: try: @@ -86,10 +95,21 @@ class AppGeneratorTTSPublisher: future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - self.msg_text += message.event.chunk.delta.message.content + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) @@ -110,16 +130,15 @@ class AppGeneratorTTSPublisher: break future_queue.put(None) - def check_and_get_audio(self) -> AudioTrunk | None: + def check_and_get_audio(self): try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: self.executor.shutdown(wait=False) - return self.last_message + return self._last_audio_event audio = self._audio_queue.get_nowait() if audio and audio.status == "finish": self.executor.shutdown(wait=False) - self._runtime_thread = None if audio: self._last_audio_event = audio return audio diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index cf0c9d7593..6339d79898 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -109,18 +109,18 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: + db_conversation_variables = session.scalars(stmt).all() + if not
db_conversation_variables: # Create conversation variables if they don't exist. - conversation_variables = [ + db_conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] - session.add_all(conversation_variables) + session.add_all(db_conversation_variables) # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] + conversation_variables = [item.to_variable() for item in db_conversation_variables] session.commit() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 2767f8a642..5e924723b6 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,6 +2,7 @@ import json import logging import time from collections.abc import Generator, Mapping +from threading import Thread from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -23,6 +24,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -64,6 +66,7 @@ from models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, + WorkflowRun, WorkflowRunStatus, ) @@ -81,6 +84,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + _conversation_name_generate_thread: Optional[Thread] = None def __init__( self, @@ -132,7 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] - def process(self): + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
:return: @@ -181,7 +185,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -198,11 +202,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -223,7 +227,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -263,8 +267,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :return: """ # init fake graph runtime state - graph_runtime_state = None - workflow_run = None + graph_runtime_state: Optional[GraphRuntimeState] = None + workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -292,13 +296,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - elif isinstance(event, QueueNodeStartedEvent): + elif isinstance( + event, + QueueNodeRetryEvent, + ): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - - response = self._workflow_node_start_to_stream_response( + response = self._workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -306,6 +314,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if response: yield response + elif isinstance(event, QueueNodeStartedEvent): + if not workflow_run: + raise ValueError("workflow run not initialized.") + + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + + response_start = self._workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response_start: + yield response_start elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -313,18 +335,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if event.node_type in [NodeType.ANSWER, NodeType.END]: 
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_finish: + yield response_finish elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -332,47 +354,48 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc - if response: - yield response + if response_finish: + yield response_finish + elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -391,10 +414,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, @@ -414,10 +437,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -496,7 +519,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( @@ -507,7 +530,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -593,7 +616,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool:
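Reviewer note: the repeated `if not workflow_run: raise ValueError(...)` guards introduced above do double duty: they replace the overly broad `Exception`, and they let mypy narrow an `Optional` local before it is used. A minimal sketch of that narrowing idiom, with illustrative names only:

```python
from typing import Optional


class WorkflowRun:
    """Stand-in for the real model; illustrative only."""


def require_run(run: Optional[WorkflowRun]) -> WorkflowRun:
    # Before the guard, `run` is Optional[WorkflowRun].
    if not run:
        raise ValueError("workflow run not initialized.")
    # After the guard, mypy narrows `run` to WorkflowRun, so it can be
    # returned or dereferenced without casts or ignores.
    return run
```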
initialized.") workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, @@ -414,10 +437,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, @@ -496,7 +519,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( @@ -507,7 +530,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) if output_moderation_answer: @@ -593,7 +616,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 417d23eccf..55b6ee510f 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -61,7 +61,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b596c99c9f..20edf5e973 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -24,6 +24,7 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -154,7 +155,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -201,8 +202,8 @@ class 
AgentChatAppGenerator(MessageBasedAppGenerator): user=user, stream=streaming, ) - - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # FIXME: type hinting issue here; ignored for now, to be fixed later + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore def _generate_worker( self, @@ -230,6 +231,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message does not exist") # chatbot app runner = AgentChatAppRunner() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d4b327922e..188b37d679 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -196,10 +196,15 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - message = db.session.query(Message).filter(Message.id == message.id).first() + conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + if conversation_result is None: + raise ValueError("Conversation not found") + message_result = db.session.query(Message).filter(Message.id == message.id).first() + if message_result is None: + raise ValueError("Message not found") db.session.close() + runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode @@ -217,12 +222,12 @@ class AgentChatAppRunner(AppRunner): runner = runner_cls( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=conversation_result, app_config=app_config, model_config=application_generate_entity.model_conf, config=agent_entity, queue_manager=queue_manager, - message=message, + message=message_result, user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 99c95b296a..82ec33b269 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -15,7 +15,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response
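Reviewer note: the `# type: ignore[override]` markers added in these converters silence a real mypy complaint: narrowing a parameter type in an override is unsound under the Liskov substitution principle, even though each converter is only ever called with its own response type. A small sketch of the pattern being suppressed, using illustrative classes that are not from the codebase:

```python
class Response:
    """Base response; illustrative only."""


class ChatResponse(Response):
    """Narrower response used by one converter."""


class Converter:
    @classmethod
    def convert(cls, response: Response) -> dict:
        return {}


class ChatConverter(Converter):
    @classmethod
    def convert(cls, response: ChatResponse) -> dict:  # type: ignore[override]
        # mypy flags this override: a Converter promises to accept any
        # Response, but this subclass accepts only ChatResponse. The ignore
        # records that callers always pass the matching subclass.
        return {"answer": "..."}
```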
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 7cb3387876..b206dabbc1 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,6 @@ import queue import time from abc import abstractmethod -from collections.abc import Generator from enum import Enum from typing import Any, Optional @@ -11,9 +10,11 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueErrorEvent, QueuePingEvent, QueueStopEvent, + WorkflowQueueMessage, ) from extensions.ext_redis import redis_client @@ -37,11 +38,11 @@ class AppQueueManager: AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) - q = queue.Queue() + q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q - def listen(self) -> Generator: + def listen(self): """ Listen to queue :return: @@ -49,7 +50,7 @@ class AppQueueManager: # wait for APP_MAX_EXECUTION_TIME seconds to stop listen listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() - last_ping_time = 0 + last_ping_time: int | float = 0 while True: try: message = self._q.get(timeout=1) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 609fd03f22..07a248d77a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -36,8 +36,8 @@ class AppRunner: app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, ) -> int: """ @@ -64,7 +64,7 @@ class AppRunner: ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -85,7 +85,7 @@ class AppRunner: prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - rest_tokens = model_context_tokens - max_tokens - prompt_tokens + rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: raise InvokeBadRequestError( "Query or prefix prompt is too long, you can reduce the prefix prompt, " @@ -111,7 +111,7 @@ class AppRunner: ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -136,8 +136,8 @@
app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, @@ -156,6 +156,7 @@ class AppRunner: """ # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform] prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( app_mode=AppMode.value_of(app_record.mode), @@ -171,8 +172,11 @@ class AppRunner: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) + prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + if not advanced_completion_prompt_template: + raise InvokeBadRequestError("Advanced completion prompt template is required.") prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: @@ -181,6 +185,8 @@ class AppRunner: assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: + if not prompt_template_entity.advanced_chat_prompt_template: + raise InvokeBadRequestError("Advanced chat prompt template is required.") prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) @@ -246,7 +252,7 @@ class AppRunner: def _handle_invoke_result( self, - invoke_result: Union[LLMResult, Generator], + invoke_result: Union[LLMResult, Generator[Any, None, None]], queue_manager: AppQueueManager, stream: bool, agent: bool = False, @@ -259,10 +265,12 @@ class AppRunner: :param agent: agent :return: """ - if not stream: + if not stream and isinstance(invoke_result, LLMResult): self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - else: + elif stream and isinstance(invoke_result, Generator): self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + else: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool @@ -291,8 +299,8 @@ class AppRunner: :param agent: agent :return: """ - model = None - prompt_messages = [] + model: str = "" + prompt_messages: list[PromptMessage] = [] text = "" usage = None for result in invoke_result: @@ -328,13 +336,14 @@ class AppRunner: def moderation_for_inputs( self, + *, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, inputs: Mapping[str, Any], - query: str, + query: str | None = None, message_id: str, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. 
:param app_id: app id @@ -350,7 +359,7 @@ class AppRunner: app_id=app_id, tenant_id=tenant_id, app_config=app_generate_entity.app_config, - inputs=inputs, + inputs=dict(inputs), query=query or "", message_id=message_id, trace_manager=app_generate_entity.trace_manager, @@ -390,9 +399,9 @@ class AppRunner: tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index fc2e771690..6a7f4312ed 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -24,6 +24,7 @@ from extensions.ext_database import db from factories import file_factory from models.account import Account from models.model import App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -146,7 +147,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, @@ -216,6 +217,8 @@ class ChatAppGenerator(MessageBasedAppGenerator): # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message does not exist") # chatbot app runner = ChatAppRunner() diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index f67df2f1ad..bfaefeb8cb 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -15,7 +15,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response.
:param blocking_response: blocking response diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 1193c4b7a4..02e5d47568 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -42,7 +42,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 79934db984..e07744c5c0 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -83,8 +83,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): query = query.replace("\x00", "") inputs = args["inputs"] - extras = {} - # get conversation conversation = None @@ -99,7 +97,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # parse files @@ -132,11 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras=extras, + extras={}, trace_manager=trace_manager, ) @@ -157,7 +155,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, @@ -197,6 +195,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): try: # get message message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() @@ -293,7 +293,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_conf=ModelConfigConverter.convert(app_config), inputs=message.inputs, query=message.query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=stream, invoke_from=invoke_from, @@ -317,7 +317,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 908d74ff53..41278b75b4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -76,7 +76,7 @@ class CompletionAppRunner(AppRunner): tenant_id=app_config.tenant_id, app_generate_entity=application_generate_entity, inputs=inputs, - query=query, + query=query or "", message_id=message.id, ) except ModerationError as e: @@ -122,7 +122,7 @@ class CompletionAppRunner(AppRunner): 
tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_conf, config=dataset_config, - query=query, + query=query or "", invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 6f8d0894d5..89dda03da1 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -15,7 +15,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -35,7 +35,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 95ae798ec1..c2e35faf89 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,11 +2,11 @@ import json import logging from collections.abc import Generator from datetime import UTC, datetime -from typing import Optional, Union +from typing import Optional, Union, cast from sqlalchemy import and_ -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( @@ -42,7 +42,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -144,7 +144,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): :conversation conversation :return: """ - app_config = application_generate_entity.app_config + app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config) # get from source end_user_id = None @@ -267,7 +267,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): except KeyError: pass - return introduction + return introduction or "" def _get_conversation(self, conversation_id: str): """ @@ -282,7 +282,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): return conversation - def _get_message(self, message_id: str) -> Message: + def _get_message(self, message_id: str) -> Optional[Message]: """ Get message by message id :param message_id: message id diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 910a99f05f..9a5f90f998 100644 ---
a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -116,7 +116,7 @@ class WorkflowAppGenerator(BaseAppGenerator): inputs=self._prepare_user_inputs( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), - files=system_files, + files=list(system_files), user_id=user.id, stream=streaming, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 72357e5e0c..cba7dc96fb 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -16,16 +16,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.to_dict() + return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index d701bbf83c..af0698d701 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -19,6 +19,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -157,7 +158,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise ValueError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -173,11 +174,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -198,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -256,9 +257,27 @@ 
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + elif isinstance( + event, + QueueNodeRetryEvent, + ): + if not workflow_run: + raise ValueError("workflow run not initialized.") + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response elif isinstance(event, QueueNodeStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) @@ -289,50 +308,50 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if node_failed_response: yield node_failed_response + elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_parallel_branch_finished_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationStartEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationNextEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_next_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueIterationCompletedEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) elif isinstance(event, QueueWorkflowSucceededEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, @@ -352,10 +371,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa ) elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise 
ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, @@ -376,10 +395,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa ) elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): if not workflow_run: - raise Exception("Workflow run not initialized.") + raise ValueError("workflow run not initialized.") if not graph_runtime_state: - raise Exception("Graph runtime state not initialized.") + raise ValueError("graph runtime state not initialized.") workflow_run = self._handle_workflow_run_failed( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, @@ -407,7 +426,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index b6f0a27f92..8dde1acfa4 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -24,6 +25,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( AgentLogEvent, @@ -40,6 +42,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -188,6 +191,40 @@ class WorkflowBasedAppRunner(AppRunner): ) elif isinstance(event, GraphRunFailedEvent): self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) + elif isinstance(event, NodeRunRetryEvent): + node_run_result = event.route_node_state.node_run_result + inputs: Mapping[str, Any] | None = {} + process_data: Mapping[str, Any] | None = {} + outputs: Mapping[str, Any] | None = {} + execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + self._publish_event( + QueueNodeRetryEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, + inputs=inputs, + process_data=process_data, + outputs=outputs, + 
error=event.error, + execution_metadata=execution_metadata, + retry_index=event.retry_index, + ) + ) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -207,6 +244,17 @@ class WorkflowBasedAppRunner(AppRunner): ) ) elif isinstance(event, NodeRunSucceededEvent): + node_run_result = event.route_node_state.node_run_result + if node_run_result: + inputs = node_run_result.inputs + process_data = node_run_result.process_data + outputs = node_run_result.outputs + execution_metadata = node_run_result.metadata + else: + inputs = {} + process_data = {} + outputs = {} + execution_metadata = {} self._publish_event( QueueNodeSucceededEvent( node_execution_id=event.id, @@ -218,18 +266,10 @@ class WorkflowBasedAppRunner(AppRunner): parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result - else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result - else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result - else {}, - execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result - else {}, + inputs=inputs, + process_data=process_data, + outputs=outputs, + execution_metadata=execution_metadata, in_iteration_id=event.in_iteration_id, ) ) @@ -251,7 +291,7 @@ class WorkflowBasedAppRunner(AppRunner): process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, error=event.route_node_state.node_run_result.error @@ -311,7 +351,7 @@ class WorkflowBasedAppRunner(AppRunner): process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, execution_metadata=event.route_node_state.node_run_result.metadata diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1f6756a4ee..9e9839dad9 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL -from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity @@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: AppConfig + app_config: Any file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index e915dea23b..fb5d0fb299 100644 --- a/api/core/app/entities/queue_entities.py +++ 
b/api/core/app/entities/queue_entities.py @@ -45,6 +45,7 @@ class QueueEvent(StrEnum): ERROR = "error" PING = "ping" STOP = "stop" + RETRY = "retry" class AppQueueEvent(BaseModel): @@ -86,9 +87,9 @@ class QueueIterationStartEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): @@ -140,9 +141,9 @@ class QueueIterationCompletedEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - metadata: Optional[dict[str, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + metadata: Optional[Mapping[str, Any]] = None steps: int = 0 error: Optional[str] = None @@ -305,10 +306,10 @@ class QueueNodeSucceededEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" @@ -330,6 +331,20 @@ class QueueAgentLogEvent(AppQueueEvent): data: Mapping[str, Any] +class QueueNodeRetryEvent(QueueNodeStartedEvent): + """QueueNodeRetryEvent entity""" + + event: QueueEvent = QueueEvent.RETRY + + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + + error: str + retry_index: int # retry index + + class QueueNodeInIterationFailedEvent(AppQueueEvent): """ QueueNodeInIterationFailedEvent entity @@ -353,10 +368,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str @@ -384,10 +399,10 @@ class QueueNodeExceptionEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str @@ -415,10 +430,10 @@ class QueueNodeFailedEvent(AppQueueEvent): """iteration id if node is in iteration""" start_at: datetime - inputs: Optional[dict[str, Any]] = None - process_data: Optional[dict[str, Any]] = None - outputs: Optional[dict[str, Any]] = None - 
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index b25336eb67..d186038c28 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -52,6 +52,7 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + NODE_RETRY = "node_retry" PARALLEL_BRANCH_STARTED = "parallel_branch_started" PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" @@ -70,7 +71,7 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) @@ -343,6 +344,75 @@ class NodeFinishStreamResponse(StreamResponse): } +class NodeRetryStreamResponse(StreamResponse): + """ + NodeRetryStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + node_id: str + node_type: str + title: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + status: str + error: Optional[str] = None + elapsed_time: float + execution_metadata: Optional[dict] = None + created_at: int + finished_at: int + files: Optional[Sequence[Mapping[str, Any]]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + retry_index: int = 0 + + event: StreamEvent = StreamEvent.NODE_RETRY + workflow_run_id: str + data: Data + + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "process_data": None, + "outputs": None, + "status": self.data.status, + "error": None, + "elapsed_time": self.data.elapsed_time, + "execution_metadata": None, + "created_at": self.data.created_at, + "finished_at": self.data.finished_at, + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + "retry_index": self.data.retry_index, + }, + } + + class ParallelBranchStartStreamResponse(StreamResponse): """ ParallelBranchStartStreamResponse entity @@ -405,8 +475,8 @@ class IterationNodeStartStreamResponse(StreamResponse): title: str created_at: int extras: dict = {} - metadata: dict = {} - inputs: dict = {} + metadata: Mapping = {} + inputs: Mapping = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -457,15 +527,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[dict] = None + outputs: Optional[Mapping] = None created_at: int extras: Optional[dict] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -559,7 +629,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self)
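Reviewer note: the recurring `dict` to `Mapping` changes in these entities trade a mutable annotation for a read-only one: any concrete `dict` still satisfies `Mapping`, but mypy then rejects accidental mutation of event payloads downstream. A short sketch of the difference, with hypothetical function names:

```python
from collections.abc import Mapping
from typing import Any


def summarize(outputs: Mapping[str, Any]) -> str:
    # Read access is fine on a Mapping...
    return str(outputs.get("output", ""))
    # ...but `outputs["output"] = "x"` here would fail type checking,
    # since Mapping defines no __setitem__.


payload: dict[str, Any] = {"output": "hello"}
print(summarize(payload))  # a plain dict still satisfies Mapping
```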
Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -559,7 +629,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 77b6bb554c..83fd3debad 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -58,7 +58,7 @@ class AnnotationReplyFeature: query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) - if documents: + if documents and documents[0].metadata: annotation_id = documents[0].metadata["annotation_id"] score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 8fe1d96b37..dcc2b4e55f 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -17,7 +17,7 @@ class RateLimit: _UNLIMITED_REQUEST_ID = "unlimited_request_id" _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes - _instance_dict = {} + _instance_dict: dict[str, "RateLimit"] = {} def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 51d610e2cb..03a81353d0 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -62,6 +62,7 @@ class BasedGenerateTaskPipeline: """ logger.debug("error: %s", event.error) e = event.error + err: Exception if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") @@ -130,6 +131,7 @@ class BasedGenerateTaskPipeline: rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), queue_manager=self._queue_manager, ) + return None def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 917649f34e..b9f8e7ca56 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,7 @@ import json import logging import time from collections.abc import Generator +from threading import Thread from typing import Optional, Union, cast from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -103,7 +104,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) ) - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Optional[Thread] = None def process( self, @@ -123,7 +124,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, 
MessageCycleMan if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + self._conversation, self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -146,7 +147,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: extras["metadata"] = self._task_state.metadata - + response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -154,7 +155,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan id=self._message.id, mode=self._conversation.mode, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -167,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan mode=self._conversation.mode, conversation_id=self._conversation.id, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -177,7 +178,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise RuntimeError("queue listening stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -201,11 +202,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -252,7 +253,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None + self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
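The hunks above make the TTS publisher optional end to end: `_listen_audio_msg` now takes `AppGeneratorTTSPublisher | None` and narrows the returned audio message before touching its attributes. A minimal sketch of that guard pattern, with `Publisher` and `AudioTrunk` here as stand-ins for the real classes:

from typing import Optional


class AudioTrunk:
    def __init__(self, status: str, audio: bytes):
        self.status = status
        self.audio = audio


class Publisher:
    def check_and_get_audio(self) -> Optional[AudioTrunk]:
        return AudioTrunk("running", b"...")


def listen_audio(publisher: Optional[Publisher]) -> Optional[bytes]:
    # Early return narrows Optional[Publisher] to Publisher for the rest of the body.
    if publisher is None:
        return None
    audio = publisher.check_and_get_audio()
    # isinstance() narrows Optional[AudioTrunk] before .status/.audio are accessed.
    if isinstance(audio, AudioTrunk) and audio.status != "finish":
        return audio.audio
    return None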
@@ -269,13 +270,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result + if event.llm_result: + self._task_state.llm_result = event.llm_result else: self._handle_stop(event) # handle output moderation output_moderation_answer = self._handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.content + cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer @@ -292,7 +294,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - yield self._agent_thought_to_stream_response(event) + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): response = self._message_file_to_stream_response(event) if response: @@ -307,16 +311,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: continue - self._task_state.llm_result.message.content += delta_text + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response(cast(str, delta_text), self._message.id) else: - yield self._agent_message_to_stream_response(delta_text, self._message.id) + yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -336,8 +342,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan llm_result = self._task_state.llm_result usage = llm_result.usage - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if not message: + raise Exception(f"Message {self._message.id} not found") + self._message = message + conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + if not conversation: + raise Exception(f"Conversation {self._conversation.id} not found") + self._conversation = conversation self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages @@ -346,7 +358,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit 
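+ # llm_result.message.content may be a list of content parts rather than a plain string; the cast below asserts the text-only case for mypy.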
self._message.answer = (
- PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
+ PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content
else ""
)
@@ -374,6 +386,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
+ and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
@@ -420,7 +433,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
- task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
+ task_id=self._application_generate_entity.task_id,
+ id=self._message.id,
+ metadata=extras.get("metadata", {}),
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -440,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:param event: agent thought event
:return:
"""
- agent_thought: MessageAgentThought = (
+ agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py
new file mode 100644
index 0000000000..e4b4168d08
--- /dev/null
+++ b/api/core/app/task_pipeline/exc.py
@@ -0,0 +1,17 @@
+class TaskPipelineError(ValueError):
+    pass
+
+
+class RecordNotFoundError(TaskPipelineError):
+    def __init__(self, record_name: str, record_id: str):
+        super().__init__(f"{record_name} with id {record_id} not found")
+
+
+class WorkflowRunNotFoundError(RecordNotFoundError):
+    def __init__(self, workflow_run_id: str):
+        super().__init__("WorkflowRun", workflow_run_id)
+
+
+class WorkflowNodeExecutionNotFoundError(RecordNotFoundError):
+    def __init__(self, workflow_node_execution_id: str):
+        super().__init__("WorkflowNodeExecution", workflow_node_execution_id)
diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py
index e818a090ed..007543f6d0 100644
--- a/api/core/app/task_pipeline/message_cycle_manage.py
+++ b/api/core/app/task_pipeline/message_cycle_manage.py
@@ -128,7 +128,7 @@ class MessageCycleManage:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
- if message_file:
+ if message_file and message_file.url is not None:
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 1b6289eb3b..115ef6ca53 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
+ QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -28,6 +29,7 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
+ NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@@ -58,6 +60,8 @@ from models.workflow import (
WorkflowRunStatus,
)
+from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
+
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
@@ -92,7 +96,7 @@ class WorkflowCycleManage:
)
# handle special values
- inputs = WorkflowEntry.handle_special_values(inputs)
+ inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
with Session(db.engine, expire_on_commit=False) as session:
@@ -191,7 +195,7 @@ class WorkflowCycleManage:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
- outputs = WorkflowEntry.handle_special_values(outputs)
+ outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
@@ -274,9 +278,9 @@ class WorkflowCycleManage:
db.session.close()
- with Session(db.engine, expire_on_commit=False) as session:
- session.add(workflow_run)
- session.refresh(workflow_run)
+ # with Session(db.engine, expire_on_commit=False) as session:
+ # session.add(workflow_run)
+ # session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
@@ -440,6 +444,59 @@ class WorkflowCycleManage:
return workflow_node_execution
+ def _handle_workflow_node_execution_retried(
+ self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+ ) -> WorkflowNodeExecution:
+ """
+ Workflow node execution retried
+ :param workflow_run: workflow run
+ :param event: queue node retry event
+ :return:
+ """
+ created_at = event.start_at
+ finished_at = datetime.now(UTC).replace(tzinfo=None)
+ elapsed_time = (finished_at - created_at).total_seconds()
+ inputs = WorkflowEntry.handle_special_values(event.inputs)
+ outputs = WorkflowEntry.handle_special_values(event.outputs)
+ origin_metadata = {
+ NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+ NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+ }
+ merged_metadata = (
+ {**jsonable_encoder(event.execution_metadata), **origin_metadata}
+ if event.execution_metadata is not None
+ else origin_metadata
+ )
+ execution_metadata = json.dumps(merged_metadata)
+
+ workflow_node_execution = WorkflowNodeExecution()
+ workflow_node_execution.tenant_id = workflow_run.tenant_id
+ workflow_node_execution.app_id = workflow_run.app_id
+ workflow_node_execution.workflow_id = workflow_run.workflow_id
+ workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+ workflow_node_execution.workflow_run_id = workflow_run.id
+ workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+ workflow_node_execution.node_execution_id = event.node_execution_id
+ workflow_node_execution.node_id = event.node_id
+ workflow_node_execution.node_type = event.node_type.value
+ workflow_node_execution.title = event.node_data.title
+ workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+ workflow_node_execution.created_by_role = workflow_run.created_by_role
+ workflow_node_execution.created_by = workflow_run.created_by
+ workflow_node_execution.created_at = created_at
+ workflow_node_execution.finished_at = finished_at
+ workflow_node_execution.elapsed_time = elapsed_time
+ workflow_node_execution.error = event.error
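+ # Persist the retried node's inputs/outputs as JSON; empty payloads are stored as NULL.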
+ workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+ workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+ workflow_node_execution.execution_metadata = execution_metadata
+ workflow_node_execution.index = event.node_run_index
+
+ db.session.add(workflow_node_execution)
+ db.session.commit()
+ db.session.refresh(workflow_node_execution)
+
+ return workflow_node_execution
+
#################################################
# to stream responses #
#################################################
@@ -460,7 +517,7 @@ class WorkflowCycleManage:
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
- inputs=dict(workflow_run.inputs_dict),
+ inputs=dict(workflow_run.inputs_dict or {}),
created_at=int(workflow_run.created_at.timestamp()),
),
)
@@ -474,6 +531,12 @@ class WorkflowCycleManage:
:param workflow_run: workflow run
:return:
"""
+ # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
+ workflow_run = db.session.merge(workflow_run)
+
+ # Refresh to ensure any expired attributes are fully loaded
+ db.session.refresh(workflow_run)
+
created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
created_by_account = workflow_run.created_by_account
@@ -499,7 +562,7 @@ class WorkflowCycleManage:
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
- outputs=dict(workflow_run.outputs_dict),
+ outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
@@ -604,6 +667,51 @@ class WorkflowCycleManage:
),
)
+ def _workflow_node_retry_to_stream_response(
+ self,
+ event: QueueNodeRetryEvent,
+ task_id: str,
+ workflow_node_execution: WorkflowNodeExecution,
+ ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
+ """
+ Workflow node retry to stream response.
+ :param event: queue node retry event
+ :param task_id: task id
+ :param workflow_node_execution: workflow node execution
+ :return:
+ """
+ if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+ return None
+
+ return NodeRetryStreamResponse(
+ task_id=task_id,
+ workflow_run_id=workflow_node_execution.workflow_run_id,
+ data=NodeRetryStreamResponse.Data(
+ id=workflow_node_execution.id,
+ node_id=workflow_node_execution.node_id,
+ node_type=workflow_node_execution.node_type,
+ index=workflow_node_execution.index,
+ title=workflow_node_execution.title,
+ predecessor_node_id=workflow_node_execution.predecessor_node_id,
+ inputs=workflow_node_execution.inputs_dict,
+ process_data=workflow_node_execution.process_data_dict,
+ outputs=workflow_node_execution.outputs_dict,
+ status=workflow_node_execution.status,
+ error=workflow_node_execution.error,
+ elapsed_time=workflow_node_execution.elapsed_time,
+ execution_metadata=workflow_node_execution.execution_metadata_dict,
+ created_at=int(workflow_node_execution.created_at.timestamp()),
+ finished_at=int(workflow_node_execution.finished_at.timestamp()),
+ files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+ parallel_id=event.parallel_id,
+ parallel_start_node_id=event.parallel_start_node_id,
+ parent_parallel_id=event.parent_parallel_id,
+ parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+ iteration_id=event.in_iteration_id,
+ retry_index=event.retry_index,
+ ),
+ )
+
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
@@ -747,7 +855,7 @@ class WorkflowCycleManage:
),
)
- def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
+ def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -760,9 +868,11 @@ class WorkflowCycleManage:
# Remove None
files = [file for file in files if file]
# Flatten list
- files = [file for sublist in files for file in sublist]
+ # Flatten the list of sequences into a single list of mappings
+ flattened_files = [file for sublist in files if sublist for file in sublist]
- return files
+ # Convert to tuple to match Sequence type
+ return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
@@ -800,6 +910,8 @@ class WorkflowCycleManage:
elif isinstance(value, File):
return value.to_dict()
+ return None
+
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
@@ -809,7 +921,7 @@ class WorkflowCycleManage:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
- raise Exception(f"Workflow run not found: {workflow_run_id}")
+ raise WorkflowRunNotFoundError(workflow_run_id)
return workflow_run
@@ -822,7 +934,7 @@ class WorkflowCycleManage:
workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
if not workflow_node_execution:
- raise Exception(f"Workflow node execution not found: {node_execution_id}")
+ raise WorkflowNodeExecutionNotFoundError(node_execution_id)
return workflow_node_execution
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index
1481578630..8f8aaa93d6 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -40,17 +40,18 @@ class DatasetIndexToolCallbackHandler: def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 9ed5528e43..5017835565 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] = [] class DefaultModelEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index f764c43098..eed2d7e49a 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -40,7 +40,7 @@ from models.provider import ( logger = logging.getLogger(__name__) -original_provider_configurate_methods = {} +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): @@ -126,7 +126,7 @@ class ProviderConfiguration(BaseModel): return credentials - def get_system_configuration_status(self) -> SystemConfigurationStatus: + def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: """ Get system configuration status. :return: @@ -138,6 +138,8 @@ class ProviderConfiguration(BaseModel): current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) + if current_quota_configuration is None: + return None if not current_quota_configuration: return SystemConfigurationStatus.UNSUPPORTED @@ -155,7 +157,7 @@ class ProviderConfiguration(BaseModel): """ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + def get_custom_credentials(self, obfuscated: bool = False): """ Get custom credentials. 
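The first hunk below tightens an untyped `defaultdict(dict)` into an explicitly parameterized one. A small sketch of why the annotation matters under mypy; the key/value types here are simplified stand-ins for ModelType and ModelSettings:

from collections import defaultdict

# Without an annotation, the inferred type has Any keys/values, so misuse goes unchecked.
untyped = defaultdict(dict)

# With the annotation, nested assignments are checked against the declared types.
model_setting_map: defaultdict[str, dict[str, bool]] = defaultdict(dict)
model_setting_map["llm"]["gpt-4"] = True  # ok: str -> dict[str, bool]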
@@ -746,7 +748,7 @@ class ProviderConfiguration(BaseModel):
model_types = provider_schema.supported_model_types
# Group model settings by model type and model
- model_setting_map = defaultdict(dict)
+ model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
@@ -844,44 +846,45 @@ class ProviderConfiguration(BaseModel):
)
except Exception as ex:
logger.warning(f"get custom model schema failed, {ex}")
- continue
- if not custom_model_schema:
- continue
+ if not custom_model_schema:
+ continue
- if custom_model_schema.model_type not in model_types:
- continue
+ if custom_model_schema.model_type not in model_types:
+ continue
- status = ModelStatus.ACTIVE
- if (
- custom_model_schema.model_type in model_setting_map
- and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
- ):
- model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
- if model_setting.enabled is False:
- status = ModelStatus.DISABLED
+ status = ModelStatus.ACTIVE
+ if (
+ custom_model_schema.model_type in model_setting_map
+ and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
+ ):
+ model_setting = model_setting_map[custom_model_schema.model_type][
+ custom_model_schema.model
+ ]
+ if model_setting.enabled is False:
+ status = ModelStatus.DISABLED
- provider_models.append(
- ModelWithProviderEntity(
- model=custom_model_schema.model,
- label=custom_model_schema.label,
- model_type=custom_model_schema.model_type,
- features=custom_model_schema.features,
- fetch_from=FetchFrom.PREDEFINED_MODEL,
- model_properties=custom_model_schema.model_properties,
- deprecated=custom_model_schema.deprecated,
- provider=SimpleModelProviderEntity(self.provider),
- status=status,
+ provider_models.append(
+ ModelWithProviderEntity(
+ model=custom_model_schema.model,
+ label=custom_model_schema.label,
+ model_type=custom_model_schema.model_type,
+ features=custom_model_schema.features,
+ fetch_from=FetchFrom.PREDEFINED_MODEL,
+ model_properties=custom_model_schema.model_properties,
+ deprecated=custom_model_schema.deprecated,
+ provider=SimpleModelProviderEntity(self.provider),
+ status=status,
+ )
)
- )
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
- for m in provider_models:
- if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
- m.status = ModelStatus.NO_PERMISSION
+ for model in provider_models:
+ if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
+ model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
- m.status = ModelStatus.QUOTA_EXCEEDED
+ model.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
diff --git a/api/core/errors/error.py b/api/core/errors/error.py
index 3b186476eb..ad921bc255 100644
--- a/api/core/errors/error.py
+++ b/api/core/errors/error.py
@@ -1,7 +1,7 @@
from typing import Optional
-class LLMError(Exception):
+class LLMError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
description = "Bad Request"
-class ProviderTokenNotInitError(Exception):
+class ProviderTokenNotInitError(ValueError):
"""
Custom exception raised when the provider token is not initialized.
""" @@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception): self.description = args[0] if args else self.description -class QuotaExceededError(Exception): +class QuotaExceededError(ValueError): """ Custom exception raised when the quota for a provider has been exceeded. """ @@ -35,7 +35,7 @@ class QuotaExceededError(Exception): description = "Quota Exceeded" -class AppInvokeQuotaExceededError(Exception): +class AppInvokeQuotaExceededError(ValueError): """ Custom exception raised when the quota for an app has been exceeded. """ @@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception): description = "App Invoke Quota Exceeded" -class ModelCurrentlyNotSupportError(Exception): +class ModelCurrentlyNotSupportError(ValueError): """ Custom exception raised when the model not support """ @@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception): description = "Model Currently Not Support" -class InvokeRateLimitError(Exception): +class InvokeRateLimitError(ValueError): """Raised when the Invoke returns rate limit error.""" description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 38cebb6b6b..3f4e20ec24 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,3 +1,5 @@ +from typing import cast + import requests from configs import dify_config @@ -5,7 +7,7 @@ from models.api_based_extension import APIBasedExtensionPoint class APIBasedExtensionRequestor: - timeout: (int, int) = (5, 60) + timeout: tuple[int, int] = (5, 60) """timeout for request connect and read""" def __init__(self, api_endpoint: str, api_key: str) -> None: @@ -51,4 +53,4 @@ class APIBasedExtensionRequestor: "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) ) - return response.json() + return cast(dict, response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 97dbaf2026..231743bf2a 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -38,8 +38,8 @@ class Extensible: @classmethod def scan_extensions(cls): - extensions: list[ModuleExtension] = [] - position_map = {} + extensions = [] + position_map: dict[str, int] = {} # get the path of the current class current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") @@ -58,7 +58,8 @@ class Extensible: # is builtin extension, builtin extension # in the front-end page and business logic, there are special treatments. 
builtin = False - position = None + # default position is 0 can not be None for sort_to_dict_by_position_map + position = 0 if "__builtin__" in file_names: builtin = True @@ -89,7 +90,7 @@ class Extensible: logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") continue - json_data = {} + json_data: dict[str, Any] = {} if not builtin: if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 3da170455e..9eb9e0306b 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -1,4 +1,6 @@ -from core.extension.extensible import ExtensionModule, ModuleExtension +from typing import cast + +from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation @@ -10,7 +12,8 @@ class Extension: def init(self): for module, module_class in self.module_classes.items(): - self.__module_extensions[module.value] = module_class.scan_extensions() + m = cast(Extensible, module_class) + self.__module_extensions[module.value] = m.scan_extensions() def module_extensions(self, module: str) -> list[ModuleExtension]: module_extensions = self.__module_extensions.get(module) @@ -35,7 +38,8 @@ class Extension: def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) - return module_extension.extension_class + t: type = module_extension.extension_class + return t def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: module_extension = self.module_extension(module, extension_name) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 54ec97a493..9989c8a090 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -48,7 +48,10 @@ class ApiExternalDataTool(ExternalDataTool): :return: the tool query result """ # get params from config + if not self.config: + raise ValueError("config is required, config: {}".format(self.config)) api_based_extension_id = self.config.get("api_based_extension_id") + assert api_based_extension_id is not None, "api_based_extension_id is required" # get api_based_extension api_based_extension = ( diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 84b94e117f..6a9703a569 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ -import concurrent import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from collections.abc import Mapping +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import Any, Optional from flask import Flask, current_app @@ -17,9 +17,9 @@ class ExternalDataFetch: tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. 
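The hunk below types the ThreadPoolExecutor fan-out and skips results whose variable name is None. A minimal, self-contained sketch of the same submit/as_completed pattern; `run_tool` is a hypothetical worker standing in for `_query_external_data_tool`:

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional


def run_tool(tool: str, query: str) -> tuple[Optional[str], Optional[str]]:
    # Hypothetical worker; the real code also passes the Flask app object
    # so each thread can re-enter an application context.
    return tool, f"result for {query}"


def fetch_all(tools: list[str], query: str) -> dict[str, str]:
    results: dict[str, str] = {}
    with ThreadPoolExecutor() as executor:
        # Map each future back to the tool that produced it.
        futures = {executor.submit(run_tool, tool, query): tool for tool in tools}
        for future in as_completed(futures):
            variable, value = future.result()
            if variable is not None and value is not None:
                results[variable] = value
    return results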
@@ -30,13 +30,14 @@ class ExternalDataFetch: :param query: the query :return: the filled inputs """ - results = {} + results: dict[str, Any] = {} + inputs = dict(inputs) with ThreadPoolExecutor() as executor: futures = {} for tool in external_data_tools: - future = executor.submit( + future: Future[tuple[str | None, str | None]] = executor.submit( self._query_external_data_tool, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore tenant_id, app_id, tool, @@ -46,9 +47,10 @@ class ExternalDataFetch: futures[future] = tool - for future in concurrent.futures.as_completed(futures): + for future in as_completed(futures): tool_variable, result = future.result() - results[tool_variable] = result + if tool_variable is not None: + results[tool_variable] = result inputs.update(results) return inputs @@ -59,7 +61,7 @@ class ExternalDataFetch: tenant_id: str, app_id: str, external_data_tool: ExternalDataVariableEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, ) -> tuple[Optional[str], Optional[str]]: """ diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 2872109859..245507e17c 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,4 +1,5 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -23,9 +24,10 @@ class ExternalDataToolFactory: """ code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - extension_class.validate_config(tenant_id, config) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: """ Query the external data tool. @@ -33,4 +35,4 @@ class ExternalDataToolFactory: :param query: the query of chat app :return: the tool query result """ - return self.__extension_instance.query(inputs, query) + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 3b83683755..4a50fb85c9 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,15 +1,15 @@ import base64 +from collections.abc import Mapping from configs import dify_config -from core.file import file_repository from core.helper import ssrf_proxy from core.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + MultiModalPromptMessageContent, VideoPromptMessageContent, ) -from extensions.ext_database import db from extensions.ext_storage import storage from . 
import helpers @@ -41,53 +41,42 @@ def to_prompt_message_content( /, *, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -): - match f.type: - case FileType.IMAGE: - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": - data = _to_url(f) - else: - data = _to_base64_data_string(f) +) -> MultiModalPromptMessageContent: + if f.extension is None: + raise ValueError("Missing file extension") + if f.mime_type is None: + raise ValueError("Missing file mime_type") - return ImagePromptMessageContent(data=data, detail=image_detail_config) - case FileType.AUDIO: - encoded_string = _get_encoded_string(f) - if f.extension is None: - raise ValueError("Missing file extension") - return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) - case FileType.VIDEO: - if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": - data = _to_url(f) - else: - data = _to_base64_data_string(f) - if f.extension is None: - raise ValueError("Missing file extension") - return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) - case FileType.DOCUMENT: - data = _get_encoded_string(f) - if f.mime_type is None: - raise ValueError("Missing file mime_type") - return DocumentPromptMessageContent( - encode_format="base64", - mime_type=f.mime_type, - data=data, - ) - case _: - raise ValueError(f"file type {f.type} is not supported") + params = { + "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", + "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", + "format": f.extension.removeprefix("."), + "mime_type": f.mime_type, + } + if f.type == FileType.IMAGE: + params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { + FileType.IMAGE: ImagePromptMessageContent, + FileType.AUDIO: AudioPromptMessageContent, + FileType.VIDEO: VideoPromptMessageContent, + FileType.DOCUMENT: DocumentPromptMessageContent, + } + + try: + return prompt_class_map[f.type].model_validate(params) + except KeyError: + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): - if f.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - return _download_file_content(tool_file.file_key) - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - return _download_file_content(upload_file.key) - # remote file - response = ssrf_proxy.get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content + if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): + return _download_file_content(f._storage_key) + elif f.transfer_method == FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + return response.content + raise ValueError(f"unsupported transfer method: {f.transfer_method}") def _download_file_content(path: str, /): @@ -118,21 +107,14 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - upload_file = file_repository.get_upload_file(session=db.session(), file=f) - data = _download_file_content(upload_file.key) + data = _download_file_content(f._storage_key) case 
FileTransferMethod.TOOL_FILE: - tool_file = file_repository.get_tool_file(session=db.session(), file=f) - data = _download_file_content(tool_file.file_key) + data = _download_file_content(f._storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string -def _to_base64_data_string(f: File, /): - encoded_string = _get_encoded_string(f) - return f"data:{f.mime_type};base64,{encoded_string}" - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py deleted file mode 100644 index 975e1e72db..0000000000 --- a/api/core/file/file_repository.py +++ /dev/null @@ -1,32 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - -from models import ToolFile, UploadFile - -from .models import File - - -def get_upload_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(UploadFile).filter( - UploadFile.id == file.related_id, - UploadFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"upload file {file.related_id} not found") - return record - - -def get_tool_file(*, session: Session, file: File): - if file.related_id is None: - raise ValueError("Missing file related_id") - stmt = select(ToolFile).filter( - ToolFile.id == file.related_id, - ToolFile.tenant_id == file.tenant_id, - ) - record = session.scalar(stmt) - if not record: - raise ValueError(f"tool file {file.related_id} not found") - return record diff --git a/api/core/file/models.py b/api/core/file/models.py index 85eb4a4823..b79acbcfa4 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -47,6 +47,38 @@ class File(BaseModel): mime_type: Optional[str] = None size: int = -1 + # Those properties are private, should not be exposed to the outside. 
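+ # As a Pydantic private attribute, _storage_key is set in __init__ below and is excluded from model_dump(), so it never appears in to_dict() output.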
+ _storage_key: str + + def __init__( + self, + *, + id: Optional[str] = None, + tenant_id: str, + type: FileType, + transfer_method: FileTransferMethod, + remote_url: Optional[str] = None, + related_id: Optional[str] = None, + filename: Optional[str] = None, + extension: Optional[str] = None, + mime_type: Optional[str] = None, + size: int = -1, + storage_key: str, + ): + super().__init__( + id=id, + tenant_id=tenant_id, + type=type, + transfer_method=transfer_method, + remote_url=remote_url, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + ) + self._storage_key = storage_key + def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index a17b7be367..6fa101cf36 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager @@ -9,4 +9,4 @@ tool_file_manager: dict[str, Any] = {"manager": None} class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return tool_file_manager["manager"] + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ea..15b501780e 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -38,7 +38,7 @@ class CodeLanguage(StrEnum): class CodeExecutor: - dependencies_cache = {} + dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { @@ -103,22 +103,22 @@ class CodeExecutor: ) try: - response = response.json() + response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response.get("code")) != 0: - raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response_data.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response = CodeExecutionResponse(**response) + response_code = CodeExecutionResponse(**response_data) - if response.data.error: - raise CodeExecutionError(response.data.error) + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) - return response.data.stdout or "" + return response_code.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index db2eb5ebb6..264947b568 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -1,9 +1,11 @@ +from collections.abc import Mapping + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: dict) -> str: + def format(cls, template: str, inputs: Mapping[str, str]) -> str: """ Format template :param template: template @@ -11,5 +13,4 @@ class Jinja2Formatter: :return: """ result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - - return result["result"] + return str(result.get("result", "")) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b7a07b21e1..baa792b5bc 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,21 +25,28 @@ class TemplateTransformer(ABC): return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") - result = result.group(1) - return result + return result.group(1) @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - return json.loads(cls.extract_result_str_from_response(response)) + try: + result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") + if not isinstance(result, dict): + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") + return result @classmethod @abstractmethod diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 96341a1b78..744fce1cf9 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -1,6 +1,5 @@ import base64 -from extensions.ext_database import db from libs import rsa @@ -14,6 +13,7 @@ def obfuscated_token(token: str): def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + from models.engine import db if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): raise ValueError(f"Tenant with id {tenant_id} not found") diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py index 
518962c165..81501d2e4e 100644 --- a/api/core/helper/lru_cache.py +++ b/api/core/helper/lru_cache.py @@ -4,7 +4,7 @@ from typing import Any class LRUCache: def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache: OrderedDict[Any, Any] = OrderedDict() self.capacity = capacity def get(self, key: Any) -> Any: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 5e274f8916..35349210bd 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -30,7 +30,7 @@ class ProviderCredentialsCache: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 899d2b6994..6a5982eca4 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -54,7 +54,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt if moderation_result is True: return True - except Exception as ex: + except Exception: logger.exception(f"Fails to check moderation, provider_name: {provider_name}") raise InvokeBadRequestError("Rate limit exceeded, please try again later.") diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 1e2fefce88..9a041667e4 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz if existed_spec: spec = existed_spec if not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - spec = importlib.util.spec_from_file_location(module_name, py_file_path) + # FIXME: mypy does not support the type of spec.loader + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore if not spec or not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports spec.loader = importlib.util.LazyLoader(spec.loader) @@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'") + logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'") raise e @@ -57,6 +58,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}") case _: - raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}") diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ef4516b404..424983a819 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] -class 
MaxRetriesExceededError(Exception): +class MaxRetriesExceededError(ValueError): """Raised when the maximum number of retries is exceeded.""" pass @@ -65,11 +65,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): except httpx.RequestError as e: logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + if max_retries == 0: + raise retries += 1 if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index e848b46c56..3b67b3f848 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -33,7 +33,7 @@ class ToolParameterCache: except JSONDecodeError: return None - return cached_tool_parameter + return dict(cached_tool_parameter) else: return None diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 2cc5d89727..2e4a04c579 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -29,7 +29,7 @@ class ToolProviderCredentialsCache: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 7be0187ad3..67fd355ee9 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -67,7 +67,7 @@ class HostingConfiguration: "base_model_name": "gpt-35-turbo", } - quotas = [] + quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, @@ -123,7 +123,7 @@ class HostingConfiguration: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT @@ -157,7 +157,7 @@ class HostingConfiguration: @staticmethod def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT @@ -187,7 +187,7 @@ class HostingConfiguration: def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -205,7 +205,7 @@ class HostingConfiguration: def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -223,7 +223,7 @@ class HostingConfiguration: def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 3bcd4c2a4d..685dbc8ed4 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -6,10 +6,10 @@ import re import threading import time import uuid -from typing import Optional, cast +from typing import Any, Optional, cast from flask import 
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 3bcd4c2a4d..685dbc8ed4 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -6,10 +6,10 @@ import re
 import threading
 import time
 import uuid
-from typing import Optional, cast
+from typing import Any, Optional, cast
 
 from flask import Flask, current_app
-from flask_login import current_user
+from flask_login import current_user  # type: ignore
 from sqlalchemy.orm.exc import ObjectDeletedError
 
 from configs import dify_config
@@ -62,6 +62,8 @@ class IndexingRunner:
                 .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                 .first()
             )
+            if not processing_rule:
+                raise ValueError("no process rule found")
             index_type = dataset_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
             # extract
@@ -120,6 +122,8 @@ class IndexingRunner:
                 .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                 .first()
             )
+            if not processing_rule:
+                raise ValueError("no process rule found")
 
             index_type = dataset_document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -254,7 +258,7 @@ class IndexingRunner:
                 tenant_id=tenant_id,
                 model_type=ModelType.TEXT_EMBEDDING,
             )
-        preview_texts = []
+        preview_texts: list[str] = []
         total_segments = 0
         index_type = doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -285,7 +289,8 @@ class IndexingRunner:
                 for upload_file_id in image_upload_file_ids:
                     image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                     try:
-                        storage.delete(image_file.key)
+                        if image_file:
+                            storage.delete(image_file.key)
                     except Exception:
                         logging.exception(
                             "Delete image_files failed while indexing_estimate, \
@@ -379,8 +384,9 @@ class IndexingRunner:
             # replace doc id to document model id
             text_docs = cast(list[Document], text_docs)
             for text_doc in text_docs:
-                text_doc.metadata["document_id"] = dataset_document.id
-                text_doc.metadata["dataset_id"] = dataset_document.dataset_id
+                if text_doc.metadata is not None:
+                    text_doc.metadata["document_id"] = dataset_document.id
+                    text_doc.metadata["dataset_id"] = dataset_document.dataset_id
 
         return text_docs
 
@@ -400,6 +406,7 @@ class IndexingRunner:
         """
         Get the NodeParser object according to the processing rule.
         """
+        character_splitter: TextSplitter
         if processing_rule.mode == "custom":
             # The user-defined segmentation rule
             rules = json.loads(processing_rule.rules)
@@ -426,9 +433,10 @@ class IndexingRunner:
             )
         else:
             # Automatic segmentation
+            automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"])
             character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
-                chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
-                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
+                chunk_size=automatic_rules["max_tokens"],
+                chunk_overlap=automatic_rules["chunk_overlap"],
                 separators=["\n\n", "。", ". ", " ", ""],
                 embedding_model_instance=embedding_model_instance,
             )
@@ -497,8 +505,8 @@ class IndexingRunner:
         """
         Split the text documents into nodes.
         """
-        all_documents = []
-        all_qa_documents = []
+        all_documents: list[Document] = []
+        all_qa_documents: list[Document] = []
         for text_doc in text_docs:
             # document clean
             document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -509,10 +517,11 @@ class IndexingRunner:
             split_documents = []
             for document_node in documents:
                 if document_node.page_content.strip():
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata["doc_id"] = doc_id
-                    document_node.metadata["doc_hash"] = hash
+                    if document_node.metadata is not None:
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(document_node.page_content)
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
                     document_node.page_content = remove_leading_symbols(page_content)
@@ -529,7 +538,7 @@ class IndexingRunner:
                 document_format_thread = threading.Thread(
                     target=self.format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
                         "tenant_id": tenant_id,
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
@@ -557,11 +566,12 @@ class IndexingRunner:
                     qa_document = Document(
                         page_content=result["question"], metadata=document_node.metadata.model_copy()
                     )
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(result["question"])
-                    qa_document.metadata["answer"] = result["answer"]
-                    qa_document.metadata["doc_id"] = doc_id
-                    qa_document.metadata["doc_hash"] = hash
+                    if qa_document.metadata is not None:
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(result["question"])
+                        qa_document.metadata["answer"] = result["answer"]
+                        qa_document.metadata["doc_id"] = doc_id
+                        qa_document.metadata["doc_hash"] = hash
                     qa_documents.append(qa_document)
                 format_documents.extend(qa_documents)
             except Exception as e:
@@ -575,7 +585,7 @@ class IndexingRunner:
         """
         Split the text documents into nodes.
         """
-        all_documents = []
+        all_documents: list[Document] = []
        for text_doc in text_docs:
             # document clean
             document_text = self._document_clean(text_doc.page_content, processing_rule)
@@ -588,11 +598,11 @@ class IndexingRunner:
             for document in documents:
                 if document.page_content is None or not document.page_content.strip():
                     continue
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(document.page_content)
-
-                document.metadata["doc_id"] = doc_id
-                document.metadata["doc_hash"] = hash
+                if document.metadata is not None:
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(document.page_content)
+                    document.metadata["doc_id"] = doc_id
+                    document.metadata["doc_hash"] = hash
 
                 split_documents.append(document)
 
@@ -648,7 +658,7 @@ class IndexingRunner:
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
-                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
             )
             create_keyword_thread.start()
             if dataset.indexing_technique == "high_quality":
@@ -659,7 +669,7 @@ class IndexingRunner:
                     futures.append(
                         executor.submit(
                             self._process_chunk,
-                            current_app._get_current_object(),
+                            current_app._get_current_object(),  # type: ignore
                             index_processor,
                             chunk_documents,
                             dataset,
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 3a92c8d9d2..9fe3f68f2a 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import re
-from typing import Optional
+from typing import Optional, cast
 
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
@@ -13,6 +13,7 @@ from core.llm_generator.prompts import (
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )
 from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -44,10 +45,13 @@ class LLMGenerator:
         prompts = [UserPromptMessage(content=prompt)]
 
         with measure_time() as timer:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
+                ),
             )
-        answer = response.message.content
+        answer = cast(str, response.message.content)
         cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
         if cleaned_answer is None:
             return ""
@@ -94,11 +98,16 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters={"max_tokens": 256, "temperature": 0},
+                    stream=False,
+                ),
             )
 
-            questions = output_parser.parse(response.message.content)
+            questions = output_parser.parse(cast(str, response.message.content))
         except InvokeError:
             questions = []
         except Exception as e:
@@ -138,11 +147,14 @@ class LLMGenerator:
         )
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                ),
             )
 
-            rule_config["prompt"] = response.message.content
+            rule_config["prompt"] = cast(str, response.message.content)
 
         except InvokeError as e:
             error = str(e)
@@ -178,15 +190,18 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider") if model_config else None,
-            model=model_config.get("name") if model_config else None,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
         )
 
         try:
             try:
                 # the first step to generate the task prompt
-                prompt_content = model_instance.invoke_llm(
-                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                prompt_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
             except InvokeError as e:
                 error = str(e)
@@ -195,8 +210,10 @@ class LLMGenerator:
 
                 return rule_config
 
-            rule_config["prompt"] = prompt_content.message.content
+            rule_config["prompt"] = cast(str, prompt_content.message.content)
 
+            if not isinstance(prompt_content.message.content, str):
+                raise NotImplementedError("prompt content is not a string")
             parameter_generate_prompt = parameter_template.format(
                 inputs={
                     "INPUT_TEXT": prompt_content.message.content,
@@ -216,19 +233,25 @@ class LLMGenerator:
             statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
 
             try:
-                parameter_content = model_instance.invoke_llm(
-                    prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                parameter_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
-                rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
+                rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
             except InvokeError as e:
                 error = str(e)
                 error_step = "generate variables"
 
             try:
-                statement_content = model_instance.invoke_llm(
-                    prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                statement_content = cast(
+                    LLMResult,
+                    model_instance.invoke_llm(
+                        prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
+                    ),
                 )
-                rule_config["opening_statement"] = statement_content.message.content
+                rule_config["opening_statement"] = cast(str, statement_content.message.content)
             except InvokeError as e:
                 error = str(e)
                 error_step = "generate conversation opener"
@@ -267,19 +290,22 @@ class LLMGenerator:
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
-            provider=model_config.get("provider") if model_config else None,
-            model=model_config.get("name") if model_config else None,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
         )
 
         prompt_messages = [UserPromptMessage(content=prompt)]
         model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
 
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+            response = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
+                ),
             )
 
-            generated_code = response.message.content
+            generated_code = cast(str, response.message.content)
             return {"code": generated_code, "language": code_language, "error": ""}
 
         except InvokeError as e:
@@ -303,9 +329,14 @@ class LLMGenerator:
 
         prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
 
-        response = model_instance.invoke_llm(
-            prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False
+        response = cast(
+            LLMResult,
+            model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters={"temperature": 0.01, "max_tokens": 2000},
+                stream=False,
+            ),
         )
 
-        answer = response.message.content
+        answer = cast(str, response.message.content)
         return answer.strip()
diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py
index 1e743f1757..0922806ca8 100644
--- a/api/core/llm_generator/output_parser/errors.py
+++ b/api/core/llm_generator/output_parser/errors.py
@@ -1,2 +1,2 @@
-class OutputParserError(Exception):
+class OutputParserError(ValueError):
     pass
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 81d08dc885..003a0c85b1 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -68,7 +68,7 @@ class TokenBufferMemory:
 
         messages = list(reversed(thread_messages))
 
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         for message in messages:
             files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
             if files:
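The llm_generator changes above all follow one pattern: invoke_llm is declared to return a Union of a blocking result and a streaming generator, so with stream=False the result has to be narrowed with typing.cast before .message.content is read. A generic sketch of that narrowing (the names below are illustrative, not Dify's):

from collections.abc import Generator
from typing import Union, cast


class Result:
    content: str = "hello"


def invoke(stream: bool) -> Union[Result, Generator[str, None, None]]:
    if stream:
        return (chunk for chunk in ["he", "llo"])
    return Result()


# cast() does nothing at runtime; it only tells the type checker which
# branch of the Union this call site actually produces.
response = cast(Result, invoke(stream=False))
print(response.content)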
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 5956ea1ae9..a64abf1e2a 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -160,17 +160,20 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")
 
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-            tools=tools,
-            stop=stop,
-            stream=stream,
-            user=user,
-            callbacks=callbacks,
+        return cast(
+            Union[LLMResult, Generator],
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user,
+                callbacks=callbacks,
+            ),
         )
 
     def get_llm_num_tokens(
@@ -187,12 +190,15 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")
 
         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.get_num_tokens,
-            model=self.model,
-            credentials=self.credentials,
-            prompt_messages=prompt_messages,
-            tools=tools,
+        return cast(
+            int,
+            self._round_robin_invoke(
+                function=self.model_type_instance.get_num_tokens,
+                model=self.model,
+                credentials=self.credentials,
+                prompt_messages=prompt_messages,
+                tools=tools,
+            ),
         )
 
     def invoke_text_embedding(
@@ -210,13 +216,16 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")
 
         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            texts=texts,
-            user=user,
-            input_type=input_type,
+        return cast(
+            TextEmbeddingResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                texts=texts,
+                user=user,
+                input_type=input_type,
+            ),
         )
 
     def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
@@ -230,11 +239,14 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")
 
         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.get_num_tokens,
-            model=self.model,
-            credentials=self.credentials,
-            texts=texts,
+        return cast(
+            int,
+            self._round_robin_invoke(
+                function=self.model_type_instance.get_num_tokens,
+                model=self.model,
+                credentials=self.credentials,
+                texts=texts,
+            ),
         )
 
     def invoke_rerank(
@@ -259,15 +271,18 @@ class ModelInstance:
             raise Exception("Model type instance is not RerankModel")
 
         self.model_type_instance = cast(RerankModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            query=query,
-            docs=docs,
-            score_threshold=score_threshold,
-            top_n=top_n,
-            user=user,
+        return cast(
+            RerankResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                query=query,
+                docs=docs,
+                score_threshold=score_threshold,
+                top_n=top_n,
+                user=user,
+            ),
         )
 
     def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
@@ -282,12 +297,15 @@ class ModelInstance:
             raise Exception("Model type instance is not ModerationModel")
 
         self.model_type_instance = cast(ModerationModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            text=text,
-            user=user,
+        return cast(
+            bool,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                text=text,
+                user=user,
+            ),
         )
 
     def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
@@ -302,12 +320,15 @@ class ModelInstance:
             raise Exception("Model type instance is not Speech2TextModel")
 
         self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            file=file,
-            user=user,
+        return cast(
+            str,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                file=file,
+                user=user,
+            ),
         )
 
     def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
@@ -324,17 +345,20 @@ class ModelInstance:
             raise Exception("Model type instance is not TTSModel")
 
         self.model_type_instance = cast(TTSModel, self.model_type_instance)
-        return self._round_robin_invoke(
-            function=self.model_type_instance.invoke,
-            model=self.model,
-            credentials=self.credentials,
-            content_text=content_text,
-            user=user,
-            tenant_id=tenant_id,
-            voice=voice,
+        return cast(
+            Iterable[bytes],
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                content_text=content_text,
+                user=user,
+                tenant_id=tenant_id,
+                voice=voice,
+            ),
         )
 
-    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
+    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any:
         """
         Round-robin invoke
         :param function: function to invoke
@@ -345,7 +369,7 @@ class ModelInstance:
         if not self.load_balancing_manager:
             return function(*args, **kwargs)
 
-        last_exception = None
+        last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None
         while True:
             lb_config = self.load_balancing_manager.fetch_next()
             if not lb_config:
@@ -499,7 +523,7 @@ class LBModelManager:
             if real_index > max_index:
                 real_index = 0
 
-            config = self._load_balancing_configs[real_index]
+            config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index]
 
             if self.in_cooldown(config):
                 cooldown_load_balancing_configs.append(config)
@@ -543,8 +567,7 @@ class LBModelManager:
             self._tenant_id, self._provider, self._model_type.value, self._model, config.id
         )
 
-        res = redis_client.exists(cooldown_cache_key)
-        res = cast(bool, res)
+        res: bool = redis_client.exists(cooldown_cache_key)
         return res
 
     @staticmethod
diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py
index 3b6b825244..1f21a2d376 100644
--- a/api/core/model_runtime/callbacks/logging_callback.py
+++ b/api/core/model_runtime/callbacks/logging_callback.py
@@ -1,7 +1,8 @@
 import json
 import logging
 import sys
-from typing import Optional
+from collections.abc import Sequence
+from typing import Optional, cast
 
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -20,7 +21,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -76,7 +77,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -94,7 +95,7 @@ class LoggingCallback(Callback):
         :param stream: is stream response
         :param user: unique user id
         """
-        sys.stdout.write(chunk.delta.message.content)
+        sys.stdout.write(cast(str, chunk.delta.message.content))
         sys.stdout.flush()
 
     def on_after_invoke(
@@ -106,7 +107,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -147,7 +148,7 @@ class LoggingCallback(Callback):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py
index 1c73755cff..c3e1351e3b 100644
--- a/api/core/model_runtime/entities/__init__.py
+++ b/api/core/model_runtime/entities/__init__.py
@@ -4,6 +4,7 @@ from .message_entities import (
     AudioPromptMessageContent,
     DocumentPromptMessageContent,
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     PromptMessageContentType,
@@ -27,6 +28,7 @@ __all__ = [
     "LLMResultChunkDelta",
     "LLMUsage",
     "ModelPropertyKey",
+    "MultiModalPromptMessageContent",
     "PromptMessage",
     "PromptMessage",
     "PromptMessageContent",
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index f2870209bb..2f682ceef5 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,7 +1,7 @@
 from abc import ABC
 from collections.abc import Sequence
 from enum import Enum, StrEnum
-from typing import Literal, Optional
+from typing import Optional
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -67,7 +67,6 @@ class PromptMessageContent(BaseModel):
     """
 
     type: PromptMessageContentType
-    data: str
 
 
 class TextPromptMessageContent(PromptMessageContent):
@@ -76,21 +75,34 @@ class TextPromptMessageContent(PromptMessageContent):
     """
 
     type: PromptMessageContentType = PromptMessageContentType.TEXT
+    data: str
 
 
-class VideoPromptMessageContent(PromptMessageContent):
+class MultiModalPromptMessageContent(PromptMessageContent):
+    """
+    Model class for multi-modal prompt message content.
+    """
+
+    type: PromptMessageContentType
+    format: str = Field(default=..., description="the format of multi-modal file")
+    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
+    url: str = Field(default="", description="the url of multi-modal file")
+    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
+
+    @property
+    def data(self):
+        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
+
+
+class VideoPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.VIDEO
-    data: str = Field(..., description="Base64 encoded video data")
-    format: str = Field(..., description="Video format")
 
 
-class AudioPromptMessageContent(PromptMessageContent):
+class AudioPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
-    data: str = Field(..., description="Base64 encoded audio data")
-    format: str = Field(..., description="Audio format")
 
 
-class ImagePromptMessageContent(PromptMessageContent):
+class ImagePromptMessageContent(MultiModalPromptMessageContent):
     """
     Model class for image prompt message content.
     """
@@ -103,11 +115,8 @@ class ImagePromptMessageContent(PromptMessageContent):
     detail: DETAIL = DETAIL.LOW
 
 
-class DocumentPromptMessageContent(PromptMessageContent):
+class DocumentPromptMessageContent(MultiModalPromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
-    encode_format: Literal["base64"]
-    mime_type: str
-    data: str
 
 
 class PromptMessage(ABC, BaseModel):
diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py
index edfb19c7d0..7675425361 100644
--- a/api/core/model_runtime/errors/invoke.py
+++ b/api/core/model_runtime/errors/invoke.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 
-class InvokeError(Exception):
+class InvokeError(ValueError):
     """Base class for all LLM exceptions."""
 
     description: Optional[str] = None
diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py
index 7fcd2133f9..16bebcc67d 100644
--- a/api/core/model_runtime/errors/validate.py
+++ b/api/core/model_runtime/errors/validate.py
@@ -1,4 +1,4 @@
-class CredentialsValidateFailedError(Exception):
+class CredentialsValidateFailedError(ValueError):
     """
     Credentials validate failed error
     """
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index a044a948aa..533bcb2878 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -1,4 +1,5 @@
 import decimal
+from abc import abstractmethod
 from typing import Optional
 
 from pydantic import BaseModel, ConfigDict, Field
@@ -37,6 +38,17 @@ class AIModel(BaseModel):
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
 
+    @abstractmethod
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        raise NotImplementedError
+
     @property
     def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
         """
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 5447058ad7..c93ab4f61e 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -41,7 +41,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[Sequence[str]] = None,
+        stop: Optional[list[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
index 4ff1c9032a..c4c1f92177 100644
--- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py
+++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -83,7 +83,8 @@ class TextEmbeddingModel(AIModel):
         model_schema = self.get_model_schema(model, credentials)
 
         if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+            content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+            return content_size
 
         return 1000
 
@@ -98,6 +99,7 @@ class TextEmbeddingModel(AIModel):
         model_schema = self.get_model_schema(model, credentials)
 
         if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            return max_chunks
 
         return 1
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml
deleted file mode 100644
index bcd59623a7..0000000000
--- a/api/core/model_runtime/model_providers/google/llm/gemini-2.0-flash-exp.yaml
+++ /dev/null
@@ -1,39 +0,0 @@
-model: gemini-2.0-flash-exp
-label:
-  en_US: Gemini 2.0 Flash Exp
-model_type: llm
-features:
-  - agent-thought
-  - vision
-  - tool-call
-  - stream-tool-call
-  - document
-model_properties:
-  mode: chat
-  context_size: 1048576
-parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top k
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-  - name: max_output_tokens
-    use_template: max_tokens
-    default: 8192
-    min: 1
-    max: 8192
-  - name: json_schema
-    use_template: json_schema
-pricing:
-  input: '0.00'
-  output: '0.00'
-  unit: '0.000001'
-  currency: USD
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml
deleted file mode 100644
index 1743d8b968..0000000000
--- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1206.yaml
+++ /dev/null
@@ -1,38 +0,0 @@
-model: gemini-exp-1206
-label:
-  en_US: Gemini exp 1206
-model_type: llm
-features:
-  - agent-thought
-  - vision
-  - tool-call
-  - stream-tool-call
-model_properties:
-  mode: chat
-  context_size: 2097152
-parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top k
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-  - name: max_output_tokens
-    use_template: max_tokens
-    default: 8192
-    min: 1
-    max: 8192
-  - name: json_schema
-    use_template: json_schema
-pricing:
-  input: '0.00'
-  output: '0.00'
-  unit: '0.000001'
-  currency: USD
diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py
index a83248c0c2..9bf055ce6b 100644
--- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py
+++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py
@@ -103,7 +103,8 @@ class OpenAIModerationModel(ModerationModel):
         model_schema = self.get_model_schema(model, credentials)
 
         if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+            max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
+            return max_characters_per_chunk
 
         return 2000
 
@@ -118,7 +119,8 @@ class OpenAIModerationModel(ModerationModel):
         model_schema = self.get_model_schema(model, credentials)
 
         if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
-            return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+            return max_chunks
 
         return 1
diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml
deleted file mode 100644
index bcd59623a7..0000000000
--- a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-2.0-flash-exp.yaml
+++ /dev/null
@@ -1,39 +0,0 @@
-model: gemini-2.0-flash-exp
-label:
-  en_US: Gemini 2.0 Flash Exp
-model_type: llm
-features:
-  - agent-thought
-  - vision
-  - tool-call
-  - stream-tool-call
-  - document
-model_properties:
-  mode: chat
-  context_size: 1048576
-parameter_rules:
-  - name: temperature
-    use_template: temperature
-  - name: top_p
-    use_template: top_p
-  - name: top_k
-    label:
-      zh_Hans: 取样数量
-      en_US: Top k
-    type: int
-    help:
-      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
-      en_US: Only sample from the top K options for each subsequent token.
-    required: false
-  - name: max_output_tokens
-    use_template: max_tokens
-    default: 8192
-    min: 1
-    max: 8192
-  - name: json_schema
-    use_template: json_schema
-pricing:
-  input: '0.00'
-  output: '0.00'
-  unit: '0.000001'
-  currency: USD
diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py
index 029ec1a581..8cc8adfc36 100644
--- a/api/core/model_runtime/schema_validators/common_validator.py
+++ b/api/core/model_runtime/schema_validators/common_validator.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Union, cast
 
 from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType
 
@@ -38,7 +38,7 @@ class CommonValidator:
 
     def _validate_credential_form_schema(
         self, credential_form_schema: CredentialFormSchema, credentials: dict
-    ) -> Optional[str]:
+    ) -> Union[str, bool, None]:
        """
         Validate credential form schema
 
@@ -47,6 +47,7 @@ class CommonValidator:
         :return: validated credential form schema value
         """
         # If the variable does not exist in credentials
+        value: Union[str, bool, None] = None
         if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
             # If required is True, an exception is thrown
             if credential_form_schema.required:
@@ -61,7 +62,7 @@ class CommonValidator:
             return None
 
         # Get the value corresponding to the variable from credentials
-        value = credentials[credential_form_schema.variable]
+        value = cast(str, credentials[credential_form_schema.variable])
 
         # If max_length=0, no validation is performed
         if credential_form_schema.max_length:
diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py
index ec1bad5698..03e3506271 100644
--- a/api/core/model_runtime/utils/encoders.py
+++ b/api/core/model_runtime/utils/encoders.py
@@ -129,7 +129,8 @@ def jsonable_encoder(
             sqlalchemy_safe=sqlalchemy_safe,
         )
     if dataclasses.is_dataclass(obj):
-        obj_dict = dataclasses.asdict(obj)
+        # FIXME: mypy error, try to fix it instead of using type: ignore
+        obj_dict = dataclasses.asdict(obj)  # type: ignore
         return jsonable_encoder(
             obj_dict,
             by_alias=by_alias,
diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py
index 2067092d80..5e8a723ec7 100644
--- a/api/core/model_runtime/utils/helper.py
+++ b/api/core/model_runtime/utils/helper.py
@@ -4,6 +4,7 @@ from pydantic import BaseModel
 
 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
-        return pydantic.model_dump(model)
+        # FIXME mypy error, try to fix it instead of using type: ignore
+        return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py
index 094ad78636..c65a3885fd 100644
--- a/api/core/moderation/api/api.py
+++ b/api/core/moderation/api/api.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from pydantic import BaseModel
 
 from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
@@ -43,6 +45,8 @@ class ApiModeration(Moderation):
     def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["inputs_config"]["enabled"]:
             params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
@@ -57,6 +61,8 @@ class ApiModeration(Moderation):
     def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["outputs_config"]["enabled"]:
             params = ModerationOutputParams(app_id=self.app_id, text=text)
@@ -69,14 +75,18 @@ class ApiModeration(Moderation):
         )
 
     def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
-        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
+        if self.config is None:
+            raise ValueError("The config is not set.")
+        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
+        if not extension:
+            raise ValueError("API-based Extension not found. Please check it again.")
         requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
 
         result = requestor.request(extension_point, params)
         return result
 
     @staticmethod
-    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
         extension = (
             db.session.query(APIBasedExtension)
             .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py
index 60898d5547..d8c392d097 100644
--- a/api/core/moderation/base.py
+++ b/api/core/moderation/base.py
@@ -100,14 +100,14 @@ class Moderation(Extensible, ABC):
             if not inputs_config.get("preset_response"):
                 raise ValueError("inputs_config.preset_response is required")
 
-            if len(inputs_config.get("preset_response")) > 100:
+            if len(inputs_config.get("preset_response", "")) > 100:
                 raise ValueError("inputs_config.preset_response must be less than 100 characters")
 
         if outputs_config_enabled:
             if not outputs_config.get("preset_response"):
                 raise ValueError("outputs_config.preset_response is required")
 
-            if len(outputs_config.get("preset_response")) > 100:
+            if len(outputs_config.get("preset_response", "")) > 100:
                 raise ValueError("outputs_config.preset_response must be less than 100 characters")
 
diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py
index 96bf2ab54b..0ad4438c14 100644
--- a/api/core/moderation/factory.py
+++ b/api/core/moderation/factory.py
@@ -22,7 +22,8 @@ class ModerationFactory:
         """
         code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
         extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
-        extension_class.validate_config(tenant_id, config)
+        # FIXME: mypy error, try to fix it instead of using type: ignore
+        extension_class.validate_config(tenant_id, config)  # type: ignore
 
     def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
         """
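A recurring fix across these moderation classes: self.config is Optional, so each public method now opens with a guard that raises before the dict is indexed. After the raise, mypy narrows the attribute from Optional[dict] to dict, which is the whole point of the pattern. Minimal sketch (class and keys are illustrative):

from typing import Optional


class ModerationBase:
    def __init__(self, config: Optional[dict] = None) -> None:
        self.config = config

    def is_enabled(self) -> bool:
        if self.config is None:
            raise ValueError("The config is not set.")
        # From here on the type checker treats self.config as dict.
        return bool(self.config.get("inputs_config", {}).get("enabled"))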
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py
index 46d3963bd0..3ac33966cb 100644
--- a/api/core/moderation/input_moderation.py
+++ b/api/core/moderation/input_moderation.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Optional
+from collections.abc import Mapping
+from typing import Any, Optional
 
 from core.app.app_config.entities import AppConfig
 from core.moderation.base import ModerationAction, ModerationError
@@ -17,11 +18,11 @@ class InputModeration:
         app_id: str,
         tenant_id: str,
         app_config: AppConfig,
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
         message_id: str,
         trace_manager: Optional[TraceQueueManager] = None,
-    ) -> tuple[bool, dict, str]:
+    ) -> tuple[bool, Mapping[str, Any], str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
@@ -33,6 +34,7 @@ class InputModeration:
         :param trace_manager: trace manager
         :return:
         """
+        inputs = dict(inputs)
         if not app_config.sensitive_word_avoidance:
             return False, inputs, query
 
diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py
index 00b3c56c03..9dd2665c3b 100644
--- a/api/core/moderation/keywords/keywords.py
+++ b/api/core/moderation/keywords/keywords.py
@@ -21,7 +21,7 @@ class KeywordsModeration(Moderation):
         if not config.get("keywords"):
             raise ValueError("keywords is required")
 
-        if len(config.get("keywords")) > 10000:
+        if len(config.get("keywords", [])) > 10000:
             raise ValueError("keywords length must be less than 10000")
 
         keywords_row_len = config["keywords"].split("\n")
@@ -31,6 +31,8 @@ class KeywordsModeration(Moderation):
     def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["inputs_config"]["enabled"]:
             preset_response = self.config["inputs_config"]["preset_response"]
@@ -50,6 +52,8 @@ class KeywordsModeration(Moderation):
     def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["outputs_config"]["enabled"]:
             # Filter out empty values
diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py
index 6465de23b9..d64f17b383 100644
--- a/api/core/moderation/openai_moderation/openai_moderation.py
+++ b/api/core/moderation/openai_moderation/openai_moderation.py
@@ -20,6 +20,8 @@ class OpenAIModeration(Moderation):
     def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["inputs_config"]["enabled"]:
             preset_response = self.config["inputs_config"]["preset_response"]
@@ -35,6 +37,8 @@ class OpenAIModeration(Moderation):
     def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
         flagged = False
         preset_response = ""
+        if self.config is None:
+            raise ValueError("The config is not set.")
 
         if self.config["outputs_config"]["enabled"]:
             flagged = self._is_violated({"text": text})
diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py
index 4635bd9c25..e595be126c 100644
--- a/api/core/moderation/output_moderation.py
+++ b/api/core/moderation/output_moderation.py
@@ -70,7 +70,7 @@ class OutputModeration(BaseModel):
         thread = threading.Thread(
             target=self.worker,
             kwargs={
-                "flask_app": current_app._get_current_object(),
+                "flask_app": current_app._get_current_object(),  # type: ignore
                 "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
             },
         )
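The input_moderation signature change above is the usual variance trick: accepting Mapping[str, Any] lets callers pass any dict-like object while promising not to mutate it; the body then takes a private dict(inputs) copy before writing. Sketch:

from collections.abc import Mapping
from typing import Any


def moderate(inputs: Mapping[str, Any]) -> Mapping[str, Any]:
    # Mapping has no __setitem__, so mypy rejects in-place writes;
    # copy first, mutate the copy, return it.
    cleaned = dict(inputs)
    cleaned["moderated"] = True
    return cleaned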
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py
index 71ff03b6ef..f0e34c0cd7 100644
--- a/api/core/ops/entities/trace_entity.py
+++ b/api/core/ops/entities/trace_entity.py
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from datetime import datetime
 from enum import StrEnum
 from typing import Any, Optional, Union
@@ -38,8 +39,8 @@ class WorkflowTraceInfo(BaseTraceInfo):
     workflow_run_id: str
     workflow_run_elapsed_time: Union[int, float]
     workflow_run_status: str
-    workflow_run_inputs: dict[str, Any]
-    workflow_run_outputs: dict[str, Any]
+    workflow_run_inputs: Mapping[str, Any]
+    workflow_run_outputs: Mapping[str, Any]
     workflow_run_version: str
     error: Optional[str] = None
     total_tokens: int
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 29fdebd8fe..b9ba068b19 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -77,8 +77,8 @@ class LangFuseDataTrace(BaseTraceInstance):
                 id=trace_id,
                 user_id=user_id,
                 name=name,
-                input=trace_info.workflow_run_inputs,
-                output=trace_info.workflow_run_outputs,
+                input=dict(trace_info.workflow_run_inputs),
+                output=dict(trace_info.workflow_run_outputs),
                 metadata=metadata,
                 session_id=trace_info.conversation_id,
                 tags=["message", "workflow"],
@@ -87,8 +87,8 @@ class LangFuseDataTrace(BaseTraceInstance):
             workflow_span_data = LangfuseSpan(
                 id=trace_info.workflow_run_id,
                 name=TraceTaskName.WORKFLOW_TRACE.value,
-                input=trace_info.workflow_run_inputs,
-                output=trace_info.workflow_run_outputs,
+                input=dict(trace_info.workflow_run_inputs),
+                output=dict(trace_info.workflow_run_outputs),
                 trace_id=trace_id,
                 start_time=trace_info.start_time,
                 end_time=trace_info.end_time,
@@ -102,8 +102,8 @@ class LangFuseDataTrace(BaseTraceInstance):
                 id=trace_id,
                 user_id=user_id,
                 name=TraceTaskName.WORKFLOW_TRACE.value,
-                input=trace_info.workflow_run_inputs,
-                output=trace_info.workflow_run_outputs,
+                input=dict(trace_info.workflow_run_inputs),
+                output=dict(trace_info.workflow_run_outputs),
                 metadata=metadata,
                 session_id=trace_info.conversation_id,
                 tags=["workflow"],
diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
index 99221d669b..348b7ba501 100644
--- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
+++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
@@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
     reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run")
     input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run")
     output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
-    dotted_order: Optional[str] = Field(None, description="Dotted order of the run")
 
     @field_validator("inputs", "outputs")
     @classmethod
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 672843e5a8..4ffd888bdd 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -3,6 +3,7 @@ import logging
 import os
 import uuid
 from datetime import datetime, timedelta
+from typing import Optional, cast
 
 from langsmith import Client
 from langsmith.schemas import RunBase
@@ -63,6 +64,8 @@ class LangSmithDataTrace(BaseTraceInstance):
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
         trace_id = trace_info.message_id or trace_info.workflow_run_id
+        if trace_info.start_time is None:
+            trace_info.start_time = datetime.now()
         message_dotted_order = (
             generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
         )
@@ -78,8 +81,8 @@ class LangSmithDataTrace(BaseTraceInstance):
             message_run = LangSmithRunModel(
                 id=trace_info.message_id,
                 name=TraceTaskName.MESSAGE_TRACE.value,
-                inputs=trace_info.workflow_run_inputs,
-                outputs=trace_info.workflow_run_outputs,
+                inputs=dict(trace_info.workflow_run_inputs),
+                outputs=dict(trace_info.workflow_run_outputs),
                 run_type=LangSmithRunType.chain,
                 start_time=trace_info.start_time,
                 end_time=trace_info.end_time,
@@ -90,6 +93,15 @@ class LangSmithDataTrace(BaseTraceInstance):
                 error=trace_info.error,
                 trace_id=trace_id,
                 dotted_order=message_dotted_order,
+                file_list=[],
+                serialized=None,
+                parent_run_id=None,
+                events=[],
+                session_id=None,
+                session_name=None,
+                reference_example_id=None,
+                input_attachments={},
+                output_attachments={},
             )
             self.add_run(message_run)
 
@@ -98,11 +110,11 @@ class LangSmithDataTrace(BaseTraceInstance):
             total_tokens=trace_info.total_tokens,
             id=trace_info.workflow_run_id,
             name=TraceTaskName.WORKFLOW_TRACE.value,
-            inputs=trace_info.workflow_run_inputs,
+            inputs=dict(trace_info.workflow_run_inputs),
             run_type=LangSmithRunType.tool,
             start_time=trace_info.workflow_data.created_at,
             end_time=trace_info.workflow_data.finished_at,
-            outputs=trace_info.workflow_run_outputs,
+            outputs=dict(trace_info.workflow_run_outputs),
             extra={
                 "metadata": metadata,
             },
@@ -111,6 +123,13 @@ class LangSmithDataTrace(BaseTraceInstance):
             parent_run_id=trace_info.message_id or None,
             trace_id=trace_id,
             dotted_order=workflow_dotted_order,
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
         )
         self.add_run(langsmith_run)
 
@@ -211,25 +230,35 @@ class LangSmithDataTrace(BaseTraceInstance):
             id=node_execution_id,
             trace_id=trace_id,
             dotted_order=node_dotted_order,
+            error="",
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
         )
         self.add_run(langsmith_run)
 
     def message_trace(self, trace_info: MessageTraceInfo):
         # get message file data
-        file_list = trace_info.file_list
-        message_file_data: MessageFile = trace_info.message_file_data
+        file_list = cast(list[str], trace_info.file_list) or []
+        message_file_data: Optional[MessageFile] = trace_info.message_file_data
         file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
         file_list.append(file_url)
         metadata = trace_info.metadata
         message_data = trace_info.message_data
+        if message_data is None:
+            return
         message_id = message_data.id
 
         user_id = message_data.from_account_id
         metadata["user_id"] = user_id
 
         if message_data.from_end_user_id:
-            end_user_data: EndUser = (
+            end_user_data: Optional[EndUser] = (
                 db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
             )
             if end_user_data is not None:
@@ -247,12 +276,20 @@ class LangSmithDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             outputs=message_data.answer,
-            extra={
-                "metadata": metadata,
-            },
+            extra={"metadata": metadata},
             tags=["message", str(trace_info.conversation_mode)],
             error=trace_info.error,
             file_list=file_list,
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            parent_run_id=None,
         )
         self.add_run(message_run)
 
@@ -267,17 +304,27 @@ class LangSmithDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             outputs=message_data.answer,
-            extra={
-                "metadata": metadata,
-            },
+            extra={"metadata": metadata},
             parent_run_id=message_id,
             tags=["llm", str(trace_info.conversation_mode)],
             error=trace_info.error,
             file_list=file_list,
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            id=str(uuid.uuid4()),
         )
         self.add_run(llm_run)
 
     def moderation_trace(self, trace_info: ModerationTraceInfo):
+        if trace_info.message_data is None:
+            return
         langsmith_run = LangSmithRunModel(
             name=TraceTaskName.MODERATION_TRACE.value,
             inputs=trace_info.inputs,
@@ -288,48 +335,82 @@ class LangSmithDataTrace(BaseTraceInstance):
                 "inputs": trace_info.inputs,
             },
             run_type=LangSmithRunType.tool,
-            extra={
-                "metadata": trace_info.metadata,
-            },
+            extra={"metadata": trace_info.metadata},
             tags=["moderation"],
             parent_run_id=trace_info.message_id,
             start_time=trace_info.start_time or trace_info.message_data.created_at,
             end_time=trace_info.end_time or trace_info.message_data.updated_at,
+            id=str(uuid.uuid4()),
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            error="",
+            file_list=[],
         )
         self.add_run(langsmith_run)
 
     def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
         message_data = trace_info.message_data
+        if message_data is None:
+            return
         suggested_question_run = LangSmithRunModel(
             name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             inputs=trace_info.inputs,
             outputs=trace_info.suggested_question,
             run_type=LangSmithRunType.tool,
-            extra={
-                "metadata": trace_info.metadata,
-            },
+            extra={"metadata": trace_info.metadata},
             tags=["suggested_question"],
             parent_run_id=trace_info.message_id,
             start_time=trace_info.start_time or message_data.created_at,
             end_time=trace_info.end_time or message_data.updated_at,
+            id=str(uuid.uuid4()),
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            error="",
+            file_list=[],
        )
         self.add_run(suggested_question_run)
 
     def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+        if trace_info.message_data is None:
+            return
         dataset_retrieval_run = LangSmithRunModel(
             name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             inputs=trace_info.inputs,
             outputs={"documents": trace_info.documents},
             run_type=LangSmithRunType.retriever,
-            extra={
-                "metadata": trace_info.metadata,
-            },
+            extra={"metadata": trace_info.metadata},
             tags=["dataset_retrieval"],
             parent_run_id=trace_info.message_id,
             start_time=trace_info.start_time or trace_info.message_data.created_at,
             end_time=trace_info.end_time or trace_info.message_data.updated_at,
+            id=str(uuid.uuid4()),
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            error="",
+            file_list=[],
         )
         self.add_run(dataset_retrieval_run)
 
@@ -347,7 +428,18 @@ class LangSmithDataTrace(BaseTraceInstance):
             parent_run_id=trace_info.message_id,
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
-            file_list=[trace_info.file_url],
+            file_list=[cast(str, trace_info.file_url)],
+            id=str(uuid.uuid4()),
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            error=trace_info.error or "",
         )
         self.add_run(tool_run)
 
@@ -358,12 +450,23 @@ class LangSmithDataTrace(BaseTraceInstance):
             inputs=trace_info.inputs,
             outputs=trace_info.outputs,
             run_type=LangSmithRunType.tool,
-            extra={
-                "metadata": trace_info.metadata,
-            },
+            extra={"metadata": trace_info.metadata},
             tags=["generate_name"],
             start_time=trace_info.start_time or datetime.now(),
             end_time=trace_info.end_time or datetime.now(),
+            id=str(uuid.uuid4()),
+            serialized=None,
+            events=[],
+            session_id=None,
+            session_name=None,
+            reference_example_id=None,
+            input_attachments={},
+            output_attachments={},
+            trace_id=None,
+            dotted_order=None,
+            error="",
+            file_list=[],
+            parent_run_id=None,
        )
         self.add_run(name_run)
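One behavioral (not just typing) change in ops_trace_manager below: workflow_trace now re-attaches the WorkflowRun to the current session with merge() and reloads it with refresh(), since trace tasks can receive an object created in another thread's session. A hedged sketch of that SQLAlchemy pattern (the helper and model name are illustrative):

from sqlalchemy.orm import Session


def reattach(session: Session, workflow_run):
    # merge() copies the detached instance's state onto one tracked by this
    # session; refresh() then reloads current column values from the database.
    merged = session.merge(workflow_run)
    session.refresh(merged)
    return merged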
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index b7799ce1fb..f538eaef5b 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -33,11 +33,11 @@ from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
 from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
+from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
 from tasks.ops_trace_task import process_trace_tasks

-provider_config_map = {
+provider_config_map: dict[str, dict[str, Any]] = {
     TracingProviderEnum.LANGFUSE.value: {
         "config_class": LangfuseConfig,
         "secret_keys": ["public_key", "secret_key"],
@@ -145,7 +145,7 @@ class OpsTraceManager:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config_data: TraceAppConfig = (
+        trace_config_data: Optional[TraceAppConfig] = (
             db.session.query(TraceAppConfig)
             .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
@@ -155,7 +155,11 @@ class OpsTraceManager:
             return None

         # decrypt_token
-        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+        app = db.session.query(App).filter(App.id == app_id).first()
+        if not app:
+            raise ValueError("App not found")
+
+        tenant_id = app.tenant_id
         decrypt_tracing_config = cls.decrypt_tracing_config(
             tenant_id, tracing_provider, trace_config_data.tracing_config
         )
@@ -178,7 +182,7 @@ class OpsTraceManager:
         if app_id is None:
             return None

-        app: App = db.session.query(App).filter(App.id == app_id).first()
+        app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()

         if app is None:
             return None
@@ -209,8 +213,12 @@ class OpsTraceManager:
     def get_app_config_through_message_id(cls, message_id: str):
         app_model_config = None
         message_data = db.session.query(Message).filter(Message.id == message_id).first()
+        if not message_data:
+            return None
         conversation_id = message_data.conversation_id
         conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+        if not conversation_data:
+            return None

         if conversation_data.app_model_config_id:
             app_model_config = (
@@ -236,7 +244,9 @@ class OpsTraceManager:
         if tracing_provider not in provider_config_map and tracing_provider is not None:
             raise ValueError(f"Invalid tracing provider: {tracing_provider}")

-        app_config: App = db.session.query(App).filter(App.id == app_id).first()
+        app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+        if not app_config:
+            raise ValueError("App not found")
         app_config.tracing = json.dumps(
             {
                 "enabled": enabled,
@@ -252,7 +262,9 @@ class OpsTraceManager:
         :param app_id: app id
         :return:
         """
-        app: App = db.session.query(App).filter(App.id == app_id).first()
+        app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+        if not app:
+            raise ValueError("App not found")
         if not app.tracing:
             return {"enabled": False, "tracing_provider": None}
         app_trace_config = json.loads(app.tracing)
@@ -355,7 +367,13 @@ class TraceTask:
     def conversation_trace(self, **kwargs):
         return kwargs

-    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
+    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
+        if not workflow_run:
+            raise ValueError("Workflow run not found")
+
+        workflow_run = db.session.merge(workflow_run)
+        db.session.refresh(workflow_run)
+
         workflow_id = workflow_run.workflow_id
         tenant_id = workflow_run.tenant_id
         workflow_run_id = workflow_run.id
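Worth pausing on workflow_trace: Session.merge() does not attach the instance you pass in; it returns a session-bound copy, which is why the rewrite keeps the return value before calling refresh(). A self-contained sketch (toy model, in-memory SQLite, SQLAlchemy 2.x):

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Run(Base):
    __tablename__ = "runs"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as s1:
    s1.add(Run(id="r1", status="running"))
    s1.commit()
    detached = s1.get(Run, "r1")

with Session(engine) as s2:
    # merge() leaves `detached` detached and returns the s2-bound copy;
    # calling refresh() on the original would raise InvalidRequestError here.
    attached = s2.merge(detached)
    s2.refresh(attached)
    print(attached.status)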
@@ -477,6 +495,8 @@ class TraceTask:
     def moderation_trace(self, message_id, timer, **kwargs):
         moderation_result = kwargs.get("moderation_result")
+        if not moderation_result:
+            return {}
         inputs = kwargs.get("inputs")
         message_data = get_message_data(message_id)
         if not message_data:
@@ -512,7 +532,7 @@ class TraceTask:
         return moderation_trace_info

     def suggested_question_trace(self, message_id, timer, **kwargs):
-        suggested_question = kwargs.get("suggested_question")
+        suggested_question = kwargs.get("suggested_question", [])
         message_data = get_message_data(message_id)
         if not message_data:
             return {}
@@ -580,7 +600,7 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query or message_data.inputs,
-            documents=[doc.model_dump() for doc in documents],
+            documents=[doc.model_dump() for doc in documents] if documents else [],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
@@ -590,9 +610,9 @@ class TraceTask:
         return dataset_retrieval_trace_info

     def tool_trace(self, message_id, timer, **kwargs):
-        tool_name = kwargs.get("tool_name")
-        tool_inputs = kwargs.get("tool_inputs")
-        tool_outputs = kwargs.get("tool_outputs")
+        tool_name = kwargs.get("tool_name", "")
+        tool_inputs = kwargs.get("tool_inputs", {})
+        tool_outputs = kwargs.get("tool_outputs", {})
         message_data = get_message_data(message_id)
         if not message_data:
             return {}
@@ -602,7 +622,7 @@ class TraceTask:
         tool_parameters = {}
         created_time = message_data.created_at
         end_time = message_data.updated_at
-        agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts
+        agent_thoughts = message_data.agent_thoughts
         for agent_thought in agent_thoughts:
             if tool_name in agent_thought.tools:
                 created_time = agent_thought.created_at
@@ -666,6 +686,8 @@ class TraceTask:
         generate_conversation_name = kwargs.get("generate_conversation_name")
         inputs = kwargs.get("inputs")
         tenant_id = kwargs.get("tenant_id")
+        if not tenant_id:
+            return {}
         start_time = timer.get("start")
         end_time = timer.get("end")

@@ -687,8 +709,8 @@ class TraceTask:
         return generate_name_trace_info


-trace_manager_timer = None
-trace_manager_queue = queue.Queue()
+trace_manager_timer: Optional[threading.Timer] = None
+trace_manager_queue: queue.Queue = queue.Queue()
 trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
 trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))

@@ -700,7 +722,7 @@ class TraceQueueManager:
         self.app_id = app_id
         self.user_id = user_id
         self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
-        self.flask_app = current_app._get_current_object()
+        self.flask_app = current_app._get_current_object()  # type: ignore
         if trace_manager_timer is None:
             self.start_timer()

@@ -717,7 +739,7 @@ class TraceQueueManager:
     def collect_tasks(self):
         global trace_manager_queue
-        tasks = []
+        tasks: list[TraceTask] = []
         while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
             task = trace_manager_queue.get_nowait()
             tasks.append(task)
@@ -743,6 +765,8 @@ class TraceQueueManager:
     def send_to_celery(self, tasks: list[TraceTask]):
         with self.flask_app.app_context():
             for task in tasks:
+                if task.app_id is None:
+                    continue
                 file_id = uuid4().hex
                 trace_info = task.execute()
                 task_data = TaskData(
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 0f3f824966..87c7a79fb0 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -1,5 +1,5 @@
-from collections.abc import Sequence
-from typing import Optional
+from collections.abc import Mapping, Sequence
+from typing import Optional, cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import file_manager
@@ -39,7 +39,7 @@ class AdvancedPromptTransform(PromptTransform):
         self,
         *,
         prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
-        inputs: dict[str, str],
+        inputs: Mapping[str, str],
         query: str,
         files: Sequence[File],
         context: Optional[str],
@@ -77,7 +77,7 @@ class AdvancedPromptTransform(PromptTransform):
     def _get_completion_model_prompt_messages(
         self,
         prompt_template: CompletionModelPromptTemplate,
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: Optional[str],
         files: Sequence[File],
         context: Optional[str],
@@ -90,15 +90,15 @@ class AdvancedPromptTransform(PromptTransform):
         """
         raw_prompt = prompt_template.text

-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []

         if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
             parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+            prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}

             prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)

-            if memory and memory_config:
+            if memory and memory_config and memory_config.role_prefix:
                 role_prefix = memory_config.role_prefix
                 prompt_inputs = self._set_histories_variable(
                     memory=memory,
@@ -135,7 +135,7 @@ class AdvancedPromptTransform(PromptTransform):
     def _get_chat_model_prompt_messages(
         self,
         prompt_template: list[ChatModelMessage],
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: Optional[str],
         files: Sequence[File],
         context: Optional[str],
@@ -146,7 +146,7 @@ class AdvancedPromptTransform(PromptTransform):
         """
         Get chat model prompt messages.
         """
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         for prompt_item in prompt_template:
             raw_prompt = prompt_item.text

@@ -160,7 +160,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt = vp.convert_template(raw_prompt).text
             else:
                 parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
-                prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
+                prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs}
                 prompt_inputs = self._set_context_variable(
                     context=context, parser=parser, prompt_inputs=prompt_inputs
                 )
@@ -207,7 +207,7 @@ class AdvancedPromptTransform(PromptTransform):
         last_message = prompt_messages[-1] if prompt_messages else None
         if last_message and last_message.role == PromptMessageRole.USER:
             # get last user message content and add files
-            prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
+            prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))]
             for file in files:
                 prompt_message_contents.append(file_manager.to_prompt_message_content(file))

@@ -229,7 +229,10 @@ class AdvancedPromptTransform(PromptTransform):

         return prompt_messages

-    def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+    def _set_context_variable(
+        self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#context#" in parser.variable_keys:
             if context:
                 prompt_inputs["#context#"] = context
@@ -238,7 +241,10 @@ class AdvancedPromptTransform(PromptTransform):

         return prompt_inputs

-    def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
+    def _set_query_variable(
+        self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str]
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#query#" in parser.variable_keys:
             if query:
                 prompt_inputs["#query#"] = query
@@ -254,9 +260,10 @@ class AdvancedPromptTransform(PromptTransform):
         raw_prompt: str,
         role_prefix: MemoryConfig.RolePrefix,
         parser: PromptTemplateParser,
-        prompt_inputs: dict,
+        prompt_inputs: Mapping[str, str],
         model_config: ModelConfigWithCredentialsEntity,
-    ) -> dict:
+    ) -> Mapping[str, str]:
+        prompt_inputs = dict(prompt_inputs)
         if "#histories#" in parser.variable_keys:
             if memory:
                 inputs = {"#histories#": "", **prompt_inputs}
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index caa1793ea8..09f017a7db 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -31,7 +31,7 @@ class AgentHistoryPromptTransform(PromptTransform):
         self.memory = memory

     def get_prompt(self) -> list[PromptMessage]:
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         num_system = 0
         for prompt_message in self.history_messages:
             if isinstance(prompt_message, SystemPromptMessage):
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 87acdb3c49..1f040599be 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -42,7 +42,7 @@ class PromptTransform:
         ):
             max_tokens = (
                 model_config.parameters.get(parameter_rule.name)
-                or model_config.parameters.get(parameter_rule.use_template)
+                or model_config.parameters.get(parameter_rule.use_template or "")
             ) or 0

         rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -59,7 +59,7 @@ class PromptTransform:
         ai_prefix: Optional[str] = None,
     ) -> str:
         """Get memory messages."""
-        kwargs = {"max_token_limit": max_token_limit}
+        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}

         if human_prefix:
             kwargs["human_prefix"] = human_prefix
@@ -76,11 +76,15 @@ class PromptTransform:
         self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
     ) -> list[PromptMessage]:
         """Get memory messages."""
-        return memory.get_history_prompt_messages(
-            max_token_limit=max_token_limit,
-            message_limit=memory_config.window.size
-            if (
-                memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0
+        return list(
+            memory.get_history_prompt_messages(
+                max_token_limit=max_token_limit,
+                message_limit=memory_config.window.size
+                if (
+                    memory_config.window.enabled
+                    and memory_config.window.size is not None
+                    and memory_config.window.size > 0
+                )
+                else None,
             )
-            else None,
         )
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 93dd92f188..e75877de9b 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -1,7 +1,8 @@
 import enum
 import json
 import os
-from typing import TYPE_CHECKING, Optional
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, cast

 from core.app.app_config.entities import PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -41,7 +42,7 @@ class ModelMode(enum.StrEnum):
         raise ValueError(f"invalid mode value {value}")


-prompt_file_contents = {}
+prompt_file_contents: dict[str, Any] = {}


 class SimplePromptTransform(PromptTransform):
@@ -53,9 +54,9 @@ class SimplePromptTransform(PromptTransform):
         self,
         app_mode: AppMode,
         prompt_template_entity: PromptTemplateEntity,
-        inputs: dict,
+        inputs: Mapping[str, str],
         query: str,
-        files: list["File"],
+        files: Sequence["File"],
         context: Optional[str],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
@@ -66,7 +67,7 @@ class SimplePromptTransform(PromptTransform):
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(
                 app_mode=app_mode,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                 inputs=inputs,
                 query=query,
                 files=files,
@@ -77,7 +78,7 @@ class SimplePromptTransform(PromptTransform):
         else:
             prompt_messages, stops = self._get_completion_model_prompt_messages(
                 app_mode=app_mode,
-                pre_prompt=prompt_template_entity.simple_prompt_template,
+                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                 inputs=inputs,
                 query=query,
                 files=files,
@@ -171,11 +172,11 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["File"],
+        files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []

         # get prompt
         prompt, _ = self.get_prompt_str_and_rules(
@@ -216,7 +217,7 @@ class SimplePromptTransform(PromptTransform):
         inputs: dict,
         query: str,
         context: Optional[str],
-        files: list["File"],
+        files: Sequence["File"],
         memory: Optional[TokenBufferMemory],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> tuple[list[PromptMessage], Optional[list[str]]]:
@@ -263,7 +264,7 @@ class SimplePromptTransform(PromptTransform):

         return [self.get_last_user_message(prompt, files)], stops

-    def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
+    def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
         if files:
             prompt_message_contents: list[PromptMessageContent] = []
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
@@ -288,7 +289,7 @@ class SimplePromptTransform(PromptTransform):

         # Check if the prompt file is already loaded
         if prompt_file_name in prompt_file_contents:
-            return prompt_file_contents[prompt_file_name]
+            return cast(dict, prompt_file_contents[prompt_file_name])

         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -301,7 +302,7 @@ class SimplePromptTransform(PromptTransform):
         # Store the content of the prompt file
         prompt_file_contents[prompt_file_name] = content

-        return content
+        return cast(dict, content)

     def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
         # baichuan
diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py
index aa175153bc..2f4e651461 100644
--- a/api/core/prompt/utils/prompt_message_util.py
+++ b/api/core/prompt/utils/prompt_message_util.py
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast

 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -72,7 +72,7 @@ class PromptMessageUtil:
                         }
                     )
             else:
-                text = prompt_message.content
+                text = cast(str, prompt_message.content)

             prompt = {"role": role, "text": text, "files": files}

@@ -99,9 +99,9 @@ class PromptMessageUtil:
                         }
                     )
             else:
-                text = prompt_message.content
+                text = cast(str, prompt_message.content)

-            params = {
+            params: dict[str, Any] = {
                 "role": "user",
                 "text": text,
             }
diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py
index 0fd08c5d3c..8e40674bc1 100644
--- a/api/core/prompt/utils/prompt_template_parser.py
+++ b/api/core/prompt/utils/prompt_template_parser.py
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Mapping

 REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")
 WITH_VARIABLE_TMPL_REGEX = re.compile(
@@ -28,7 +29,7 @@ class PromptTemplateParser:
         # Regular expression to match the template rules
         return re.findall(self.regex, self.template)

-    def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
+    def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str:
         def replacer(match):
             key = match.group(1)
             value = inputs.get(key, match.group(0))  # return original matched string if key not found
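To make the parser change concrete, here is the substitution logic format() implements, reduced to a standalone function (simplified: unknown variables are left in place, mirroring the replacer above):

import re
from collections.abc import Mapping

REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}")


def format_template(template: str, inputs: Mapping[str, str]) -> str:
    def replacer(match: re.Match) -> str:
        key = match.group(1)
        # Return the original {{key}} text when the key is missing,
        # as PromptTemplateParser.format does above.
        return inputs.get(key, match.group(0))

    return re.sub(REGEX, replacer, template)


print(format_template("Hello {{name}}, {{missing}}!", {"name": "Dify"}))
# -> Hello Dify, {{missing}}!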
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 534da4622b..a7e9bc809f 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -1,7 +1,7 @@
 import json
 from collections import defaultdict
 from json import JSONDecodeError
-from typing import Optional
+from typing import Optional, cast

 from sqlalchemy.exc import IntegrityError

@@ -16,6 +16,7 @@ from core.entities.provider_entities import (
     ModelSettings,
     ProviderQuotaType,
     QuotaConfiguration,
+    QuotaUnit,
     SystemConfiguration,
 )
 from core.helper import encrypter
@@ -117,8 +118,8 @@ class ProviderManager:
         for provider_entity in provider_entities:
             # handle include, exclude
             if is_filtered(
-                include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
-                exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
+                include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
+                exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
                 data=provider_entity,
                 name_func=lambda x: x.provider,
             ):
@@ -349,7 +350,7 @@ class ProviderManager:
         :param tenant_id: workspace id
         :return:
         """
-        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
+        providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid.is_(True)).all()

         provider_name_to_provider_records_dict = defaultdict(list)
         for provider in providers:
@@ -368,7 +369,7 @@ class ProviderManager:
         # Get all provider model records of the workspace
         provider_models = (
             db.session.query(ProviderModel)
-            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
+            .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid.is_(True))
             .all()
         )
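A note on the `.is_(True)` filters above: plain `== True` trips the linter's equality-to-True rule, but Python's `is True` is evaluated eagerly on the Column object and silently turns the filter into a constant false. SQLAlchemy's `.is_()` is the lint-clean form that still generates SQL. A runnable sketch with a toy model (not Dify's real Provider):

from sqlalchemy import Boolean, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Provider(Base):
    __tablename__ = "providers"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    is_valid: Mapped[bool] = mapped_column(Boolean)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Provider(id="a", is_valid=True), Provider(id="b", is_valid=False)])
    session.commit()

    print(Provider.is_valid is True)  # False: identity test on the column object
    rows = session.scalars(select(Provider).where(Provider.is_valid.is_(True))).all()
    print([p.id for p in rows])  # ['a']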
@@ -491,14 +492,16 @@ class ProviderManager:
         # Init trial provider records if not exists
         if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
             try:
-                provider_record = Provider()
-                provider_record.tenant_id = tenant_id
-                provider_record.provider_name = provider_name
-                provider_record.provider_type = ProviderType.SYSTEM.value
-                provider_record.quota_type = ProviderQuotaType.TRIAL.value
-                provider_record.quota_limit = quota.quota_limit
-                provider_record.quota_used = 0
-                provider_record.is_valid = True
+                # FIXME: ignore the type error for now; only TrialHostingQuota has a
+                # quota_limit, so this logic needs to change.
+                provider_record = Provider(
+                    tenant_id=tenant_id,
+                    provider_name=provider_name,
+                    provider_type=ProviderType.SYSTEM.value,
+                    quota_type=ProviderQuotaType.TRIAL.value,
+                    quota_limit=quota.quota_limit,  # type: ignore
+                    quota_used=0,
+                    is_valid=True,
+                )
                 db.session.add(provider_record)
                 db.session.commit()
             except IntegrityError:
@@ -589,7 +592,9 @@ class ProviderManager:
                 if variable in provider_credentials:
                     try:
                         provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                            provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa
+                            provider_credentials.get(variable) or "",  # type: ignore
+                            self.decoding_rsa_key,
+                            self.decoding_cipher_rsa,
                         )
                     except ValueError:
                         pass
@@ -671,13 +676,9 @@ class ProviderManager:
         # Get hosting configuration
         hosting_configuration = ext_hosting_provider.hosting_configuration

-        if (
-            provider_entity.provider not in hosting_configuration.provider_map
-            or not hosting_configuration.provider_map.get(provider_entity.provider).enabled
-        ):
-            return SystemConfiguration(enabled=False)
-
         provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider)
+        if provider_hosting_configuration is None or not provider_hosting_configuration.enabled:
+            return SystemConfiguration(enabled=False)

         # Convert provider_records to dict
         quota_type_to_provider_records_dict = {}
@@ -688,14 +689,13 @@ class ProviderManager:
             quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
                 provider_record
             )
-
         quota_configurations = []
         for provider_quota in provider_hosting_configuration.quotas:
             if provider_quota.quota_type not in quota_type_to_provider_records_dict:
                 if provider_quota.quota_type == ProviderQuotaType.FREE:
                     quota_configuration = QuotaConfiguration(
                         quota_type=provider_quota.quota_type,
-                        quota_unit=provider_hosting_configuration.quota_unit,
+                        quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
                         quota_used=0,
                         quota_limit=0,
                         is_valid=False,
@@ -708,7 +708,7 @@ class ProviderManager:

                 quota_configuration = QuotaConfiguration(
                     quota_type=provider_quota.quota_type,
-                    quota_unit=provider_hosting_configuration.quota_unit,
+                    quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
                     quota_used=provider_record.quota_used,
                     quota_limit=provider_record.quota_limit,
                     is_valid=provider_record.quota_limit > provider_record.quota_used
@@ -725,12 +725,12 @@ class ProviderManager:
         current_using_credentials = provider_hosting_configuration.credentials

         if current_quota_type == ProviderQuotaType.FREE:
-            provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
+            provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type)

-            if provider_record:
+            if provider_record_quota_free:
                 provider_credentials_cache = ProviderCredentialsCache(
                     tenant_id=tenant_id,
-                    identity_id=provider_record.id,
+                    identity_id=provider_record_quota_free.id,
                     cache_type=ProviderCredentialsCacheType.PROVIDER,
                 )

@@ -763,7 +763,7 @@ class ProviderManager:
                     except ValueError:
                         pass

-                current_using_credentials = provider_credentials
+                current_using_credentials = provider_credentials or {}

                 # cache provider credentials
                 provider_credentials_cache.set(credentials=current_using_credentials)
@@ -842,7 +842,7 @@ class ProviderManager:
             else []
         )

-        model_settings = []
+        model_settings: list[ModelSettings] = []
         if not provider_model_settings:
             return model_settings
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py
index 992415657e..d17d76333e 100644
--- a/api/core/rag/data_post_processor/data_post_processor.py
+++ b/api/core/rag/data_post_processor/data_post_processor.py
@@ -83,11 +83,15 @@ class DataPostProcessor:
         if reranking_model:
             try:
                 model_manager = ModelManager()
+                reranking_provider_name = reranking_model.get("reranking_provider_name")
+                reranking_model_name = reranking_model.get("reranking_model_name")
+                if not reranking_provider_name or not reranking_model_name:
+                    return None
                 rerank_model_instance = model_manager.get_model_instance(
                     tenant_id=tenant_id,
-                    provider=reranking_model["reranking_provider_name"],
+                    provider=reranking_provider_name,
                     model_type=ModelType.RERANK,
-                    model=reranking_model["reranking_model_name"],
+                    model=reranking_model_name,
                 )
                 return rerank_model_instance
             except InvokeAuthorizationError:
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index a0153c1e58..95a2316f1d 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -32,8 +32,11 @@ class Jieba(BaseKeyword):
                 keywords = keyword_table_handler.extract_keywords(
                     text.page_content, self._config.max_keywords_per_chunk
                 )
-                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+                if text.metadata is not None:
+                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                    keyword_table = self._add_text_to_keyword_table(
+                        keyword_table or {}, text.metadata["doc_id"], list(keywords)
+                    )

             self._save_dataset_keyword_table(keyword_table)

@@ -58,20 +61,26 @@ class Jieba(BaseKeyword):
                 keywords = keyword_table_handler.extract_keywords(
                     text.page_content, self._config.max_keywords_per_chunk
                 )
-                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+                if text.metadata is not None:
+                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                    keyword_table = self._add_text_to_keyword_table(
+                        keyword_table or {}, text.metadata["doc_id"], list(keywords)
+                    )

             self._save_dataset_keyword_table(keyword_table)

     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
+        if keyword_table is None:
+            return False
         return id in set.union(*keyword_table.values())

     def delete_by_ids(self, ids: list[str]) -> None:
         lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table = self._get_dataset_keyword_table()
-            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+            if keyword_table is not None:
+                keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

             self._save_dataset_keyword_table(keyword_table)

@@ -80,7 +89,7 @@ class Jieba(BaseKeyword):

         k = kwargs.get("top_k", 4)

-        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
+        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

         documents = []
         for chunk_index in sorted_chunk_indices:
@@ -137,7 +146,7 @@ class Jieba(BaseKeyword):
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return keyword_table_dict["__data__"]["table"]
+                return dict(keyword_table_dict["__data__"]["table"])
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
@@ -188,8 +197,8 @@ class Jieba(BaseKeyword):
         # go through text chunks in order of most matching keywords
         chunk_indices_count: dict[str, int] = defaultdict(int)
-        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
-        for keyword in keywords:
+        keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+        for keyword in keywords_list:
             for node_id in keyword_table[keyword]:
                 chunk_indices_count[node_id] += 1

@@ -215,7 +224,7 @@ class Jieba(BaseKeyword):
     def create_segment_keywords(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
         self._update_segment_keywords(self.dataset.id, node_id, keywords)
-        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)

     def multi_create_segment_keywords(self, pre_segment_data_list: list):
@@ -226,17 +235,19 @@ class Jieba(BaseKeyword):
             if pre_segment_data["keywords"]:
                 segment.keywords = pre_segment_data["keywords"]
                 keyword_table = self._add_text_to_keyword_table(
-                    keyword_table, segment.index_node_id, pre_segment_data["keywords"]
+                    keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
                 keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
                 segment.keywords = list(keywords)
-                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
+                keyword_table = self._add_text_to_keyword_table(
+                    keyword_table or {}, segment.index_node_id, list(keywords)
+                )
         self._save_dataset_keyword_table(keyword_table)

     def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 4b1ade8e3f..8b17e8dc0a 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,18 +1,19 @@
 import re
 from typing import Optional

-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-

 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse  # type: ignore
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba  # type: ignore
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:

     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py
index be00687abd..b261b40b72 100644
--- a/api/core/rag/datasource/keyword/keyword_base.py
+++ b/api/core/rag/datasource/keyword/keyword_base.py
@@ -37,6 +37,8 @@ class BaseKeyword(ABC):

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts.copy():
+            if text.metadata is None:
+                continue
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
@@ -45,4 +47,4 @@ class BaseKeyword(ABC):
         return texts

     def _get_uuids(self, texts: list[Document]) -> list[str]:
-        return [text.metadata["doc_id"] for text in texts]
+        return [text.metadata["doc_id"] for text in texts if text.metadata]
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 18f8d4e839..34343ad60e 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -6,6 +6,7 @@ from flask import Flask, current_app
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
@@ -31,7 +32,7 @@ class RetrievalService:
         top_k: int,
         score_threshold: Optional[float] = 0.0,
         reranking_model: Optional[dict] = None,
-        reranking_mode: Optional[str] = "reranking_model",
+        reranking_mode: str = "reranking_model",
         weights: Optional[dict] = None,
     ):
         if not query:
@@ -42,15 +43,15 @@ class RetrievalService:
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []

-        all_documents = []
-        threads = []
-        exceptions = []
+        all_documents: list[Document] = []
+        threads: list[threading.Thread] = []
+        exceptions: list[str] = []
         # retrieval_model source with keyword
         if retrieval_method == "keyword_search":
             keyword_thread = threading.Thread(
                 target=RetrievalService.keyword_search,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset_id,
                     "query": query,
                     "top_k": top_k,
@@ -65,7 +66,7 @@ class RetrievalService:
             embedding_thread = threading.Thread(
                 target=RetrievalService.embedding_search,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset_id,
                     "query": query,
                     "top_k": top_k,
@@ -84,7 +85,7 @@ class RetrievalService:
             full_text_index_thread = threading.Thread(
                 target=RetrievalService.full_text_index_search,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset_id,
                     "query": query,
                     "retrieval_method": retrieval_method,
@@ -124,7 +125,7 @@ class RetrievalService:
         if not dataset:
             return []
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
-            dataset.tenant_id, dataset_id, query, external_retrieval_model
+            dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
         )
         return all_documents

@@ -135,6 +136,8 @@ class RetrievalService:
         with flask_app.app_context():
             try:
                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise ValueError("dataset not found")

                 keyword = Keyword(dataset=dataset)

@@ -159,6 +162,8 @@ class RetrievalService:
         with flask_app.app_context():
             try:
                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise ValueError("dataset not found")

                 vector = Vector(dataset=dataset)

@@ -209,6 +214,8 @@ class RetrievalService:
         with flask_app.app_context():
             try:
                 dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise ValueError("dataset not found")

                 vector_processor = Vector(
                     dataset=dataset,
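A note on the repeated current_app._get_current_object() pattern in the retrieval threads above: current_app is a thread-local proxy, so each worker receives the unwrapped app object and re-enters an application context itself. A minimal sketch:

import threading

from flask import Flask, current_app

app = Flask(__name__)


def worker(flask_app: Flask, results: list[str]):
    # current_app is unusable in this thread until a context is pushed again.
    with flask_app.app_context():
        results.append(current_app.name)


results: list[str] = []
with app.app_context():
    concrete = current_app._get_current_object()  # unwrap the proxy
    thread = threading.Thread(target=worker, args=(concrete, results))
    thread.start()
    thread.join()
print(results)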
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
index 09104ae422..603d3fdbcd 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
@@ -17,12 +17,19 @@ from models.dataset import Dataset

 class AnalyticdbVector(BaseVector):
     def __init__(
-        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
+        self,
+        collection_name: str,
+        api_config: AnalyticdbVectorOpenAPIConfig | None,
+        sql_config: AnalyticdbVectorBySqlConfig | None,
     ):
         super().__init__(collection_name)
         if api_config is not None:
-            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
+            self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
+                collection_name, api_config
+            )
         else:
+            if sql_config is None:
+                raise ValueError("Either api_config or sql_config must be provided")
             self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

     def get_type(self) -> str:
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
         self.analyticdb_vector._create_collection_if_not_exists(dimension)
         self.analyticdb_vector.add_texts(texts, embeddings)

-    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        self.analyticdb_vector.add_texts(texts, embeddings)
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        self.analyticdb_vector.add_texts(documents, embeddings)

     def text_exists(self, id: str) -> bool:
         return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
         if dify_config.ANALYTICDB_HOST is None:
             # implemented through OpenAPI
             apiConfig = AnalyticdbVectorOpenAPIConfig(
-                access_key_id=dify_config.ANALYTICDB_KEY_ID,
-                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
-                region_id=dify_config.ANALYTICDB_REGION_ID,
-                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
+                region_id=dify_config.ANALYTICDB_REGION_ID or "",
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
                 namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
             )
             sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
             sqlConfig = AnalyticdbVectorBySqlConfig(
                 host=dify_config.ANALYTICDB_HOST,
                 port=dify_config.ANALYTICDB_PORT,
-                account=dify_config.ANALYTICDB_ACCOUNT,
-                account_password=dify_config.ANALYTICDB_PASSWORD,
+                account=dify_config.ANALYTICDB_ACCOUNT or "",
+                account_password=dify_config.ANALYTICDB_PASSWORD or "",
                 min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                 max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
-                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
             )
             apiConfig = None
         return AnalyticdbVector(
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 05e0ebc54f..095752ea8e 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any
+from typing import Any, Optional

 from pydantic import BaseModel, model_validator

@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
     account: str
     account_password: str
     namespace: str = "dify"
-    namespace_password: str = (None,)
+    namespace_password: Optional[str] = None
     metrics: str = "cosine"
     read_timeout: int = 60000

@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
 class AnalyticdbVectorOpenAPI:
     def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
         try:
-            from alibabacloud_gpdb20160503.client import Client
-            from alibabacloud_tea_openapi import models as open_api_models
+            from alibabacloud_gpdb20160503.client import Client  # type: ignore
+            from alibabacloud_tea_openapi import models as open_api_models  # type: ignore
         except:
             raise ImportError(_import_err_msg)
         self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
             redis_client.set(database_exist_cache_key, 1, ex=3600)

     def _initialize_vector_database(self) -> None:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models  # type: ignore

         request = gpdb_20160503_models.InitVectorDatabaseRequest(
             dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:
     def _create_namespace_if_not_exists(self) -> None:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
+        from Tea.exceptions import TeaException  # type: ignore

         try:
             request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:
         rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
         for doc, embedding in zip(documents, embeddings, strict=True):
-            metadata = {
-                "ref_doc_id": doc.metadata["doc_id"],
-                "page_content": doc.page_content,
-                "metadata_": json.dumps(doc.metadata),
-            }
-            rows.append(
-                gpdb_20160503_models.UpsertCollectionDataRequestRows(
-                    vector=embedding,
-                    metadata=metadata,
+            if doc.metadata is not None:
+                metadata = {
+                    "ref_doc_id": doc.metadata["doc_id"],
+                    "page_content": doc.page_content,
+                    "metadata_": json.dumps(doc.metadata),
+                }
+                rows.append(
+                    gpdb_20160503_models.UpsertCollectionDataRequestRows(
+                        vector=embedding,
+                        metadata=metadata,
+                    )
                 )
-            )
         request = gpdb_20160503_models.UpsertCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
                     metadata=metadata,
                 )
                 documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
                     metadata=metadata,
                 )
                 documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return documents

     def delete(self) -> None:
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index e474db5cb2..4d8f792941 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -3,8 +3,8 @@ import uuid
 from contextlib import contextmanager
 from typing import Any

-import psycopg2.extras
-import psycopg2.pool
+import psycopg2.extras  # type: ignore
+import psycopg2.pool  # type: ignore
 from pydantic import BaseModel, model_validator

 from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:

     @contextmanager
     def _get_cursor(self):
+        assert self.pool is not None, "Connection pool is not initialized"
         conn = self.pool.getconn()
         cur = conn.cursor()
         try:
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
             VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
         """
         for i, doc in enumerate(documents):
-            values.append(
-                (
-                    id_prefix + str(i),
-                    doc.metadata.get("doc_id", str(uuid.uuid4())),
-                    embeddings[i],
-                    doc.page_content,
-                    json.dumps(doc.metadata),
-                    doc.page_content,
+            if doc.metadata is not None:
+                values.append(
+                    (
+                        id_prefix + str(i),
+                        doc.metadata.get("doc_id", str(uuid.uuid4())),
+                        embeddings[i],
+                        doc.page_content,
+                        json.dumps(doc.metadata),
+                        doc.page_content,
+                    )
                 )
-            )
         with self._get_cursor() as cur:
             psycopg2.extras.execute_batch(cur, sql, values)
diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
index eb78e8aa69..85596ad20e 100644
--- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py
+++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
@@ -5,13 +5,13 @@ from typing import Any

 import numpy as np
 from pydantic import BaseModel, model_validator
-from pymochow import MochowClient
-from pymochow.auth.bce_credentials import BceCredentials
-from pymochow.configuration import Configuration
-from pymochow.exception import ServerError
-from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
-from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
-from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
+from pymochow import MochowClient  # type: ignore
+from pymochow.auth.bce_credentials import BceCredentials  # type: ignore
+from pymochow.configuration import Configuration  # type: ignore
+from pymochow.exception import ServerError  # type: ignore
+from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState  # type: ignore
+from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex  # type: ignore
+from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row  # type: ignore

 from configs import dify_config
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         texts = [doc.page_content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
+        metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
         total_count = len(documents)
         batch_size = 1000

@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
         for start in range(0, total_count, batch_size):
             end = min(start + batch_size, total_count)
             rows = []
+            assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
+            # FIXME: this assert fires whenever a document lacks metadata; pairing
+            # documents with their metadata (instead of filtering) would remove it.
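Following up on that FIXME: a standalone sketch of the pairing alternative, with a toy Doc type standing in for the project's Document model (the real code builds pymochow Row objects). Missing metadata is defaulted rather than filtered, so the indices can never drift and the assert becomes unnecessary:

import uuid
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Doc:
    page_content: str
    metadata: Optional[dict[str, Any]] = None


def rows_for_upsert(documents: list[Doc], embeddings: list[list[float]]) -> list[dict[str, Any]]:
    # zip(strict=True) raises if the two lists ever diverge, and defaulting
    # the metadata keeps every row aligned without pre-filtering the list.
    rows = []
    for doc, embedding in zip(documents, embeddings, strict=True):
        metadata = doc.metadata or {}
        rows.append(
            {
                "id": metadata.get("doc_id", str(uuid.uuid4())),
                "vector": embedding,
                "page_content": doc.page_content,
            }
        )
    return rows


print(rows_for_upsert([Doc("hello")], [[0.1, 0.2]]))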
             for i in range(start, end, 1):
                 row = Row(
                     id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
         # baidu vector database doesn't support bm25 search on current version
         return []

-    def _get_search_res(self, res, score_threshold):
+    def _get_search_res(self, res, score_threshold) -> list[Document]:
         docs = []
         for row in res.rows:
             row_data = row.get("row", {})
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
         return BaiduVector(
             collection_name=collection_name,
             config=BaiduConfig(
-                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
+                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
                 connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
-                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
-                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
-                database=dify_config.BAIDU_VECTOR_DB_DATABASE,
+                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
+                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
+                database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
                 shard=dify_config.BAIDU_VECTOR_DB_SHARD,
                 replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
             ),
diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
index a9e1486edd..0eab01b507 100644
--- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py
+++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
         metadatas = [d.metadata for d in documents]

         collection = self._client.get_or_create_collection(self._collection_name)
-        collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
+        # FIXME: chromadb using numpy array, fix the type error later
+        collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)  # type: ignore

     def delete_by_metadata_field(self, key: str, value: str):
         collection = self._client.get_or_create_collection(self._collection_name)
-        collection.delete(where={key: {"$eq": value}})
+        # FIXME: fix the type error later
+        collection.delete(where={key: {"$eq": value}})  # type: ignore

     def delete(self):
         self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
         score_threshold = float(kwargs.get("score_threshold") or 0.0)

-        ids: list[str] = results["ids"][0]
-        documents: list[str] = results["documents"][0]
-        metadatas: dict[str, Any] = results["metadatas"][0]
-        distances: list[float] = results["distances"][0]
+        # Check if results contain data
+        if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
+            return []
+
+        ids = results["ids"][0]
+        documents = results["documents"][0]
+        metadatas = results["metadatas"][0]
+        distances = results["distances"][0]

         docs = []
         for index in range(len(ids)):
             distance = distances[index]
-            metadata = metadatas[index]
+            metadata = dict(metadatas[index])
             if distance >= score_threshold:
                 metadata["score"] = distance
                 doc = Document(
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
                 )
                 docs.append(doc)
         # Sort the documents by score in descending order
-        docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
+        docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
         return ChromaVector(
             collection_name=collection_name,
             config=ChromaConfig(
-                host=dify_config.CHROMA_HOST,
+                host=dify_config.CHROMA_HOST or "",
                 port=dify_config.CHROMA_PORT,
                 tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
                 database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
index d26726e864..68a9952789 100644
--- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
+++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
@@ -5,14 +5,14 @@ import uuid
 from datetime import timedelta
 from typing import Any

-from couchbase import search
-from couchbase.auth import PasswordAuthenticator
-from couchbase.cluster import Cluster
-from couchbase.management.search import SearchIndex
+from couchbase import search  # type: ignore
+from couchbase.auth import PasswordAuthenticator  # type: ignore
+from couchbase.cluster import Cluster  # type: ignore
+from couchbase.management.search import SearchIndex  # type: ignore

 # needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
-from couchbase.options import ClusterOptions, SearchOptions
-from couchbase.vector_search import VectorQuery, VectorSearch
+from couchbase.options import ClusterOptions, SearchOptions  # type: ignore
+from couchbase.vector_search import VectorQuery, VectorSearch  # type: ignore
 from flask import current_app
 from pydantic import BaseModel, model_validator

@@ -231,7 +231,7 @@ class CouchbaseVector(BaseVector):
         # Pass the id as a parameter to the query
         result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
         for row in result:
-            return row["count"] > 0
+            return bool(row["count"] > 0)
         return False  # Return False if no rows are returned

     def delete_by_ids(self, ids: list[str]) -> None:
@@ -369,10 +369,10 @@ class CouchbaseVectorFactory(AbstractVectorFactory):
         return CouchbaseVector(
             collection_name=collection_name,
             config=CouchbaseConfig(
-                connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
-                user=config.get("COUCHBASE_USER"),
-                password=config.get("COUCHBASE_PASSWORD"),
-                bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
-                scope_name=config.get("COUCHBASE_SCOPE_NAME"),
+                connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
+                user=config.get("COUCHBASE_USER", ""),
+                password=config.get("COUCHBASE_PASSWORD", ""),
+                bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
+                scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
             ),
         )
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index b08811a021..8661828dc2 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import math
-from typing import Any, Optional
+from typing import Any, Optional, cast
 from urllib.parse import urlparse

 import requests

@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):

     def _get_version(self) -> str:
         info = self._client.info()
-        return info["version"]["number"]
+        return cast(str, info["version"]["number"])

     def _check_version(self):
         if self._version < "8.0.0":
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if score > score_threshold:
-                doc.metadata["score"] = score
+                if doc.metadata is not None:
+                    doc.metadata["score"] = score
             docs.append(doc)

         return docs

@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
         return docs

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        metadatas = [d.metadata for d in texts]
+        metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
         self.create_collection(embeddings, metadatas)
         self.add_texts(texts, embeddings, **kwargs)

     def create_collection(
-        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+        self,
+        embeddings: list[list[float]],
+        metadatas: Optional[list[dict[Any, Any]]] = None,
+        index_params: Optional[dict] = None,
     ):
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
         return ElasticSearchVector(
             index_name=collection_name,
             config=ElasticSearchConfig(
-                host=config.get("ELASTICSEARCH_HOST"),
-                port=config.get("ELASTICSEARCH_PORT"),
-                username=config.get("ELASTICSEARCH_USERNAME"),
-                password=config.get("ELASTICSEARCH_PASSWORD"),
+                host=config.get("ELASTICSEARCH_HOST", "localhost"),
+                port=config.get("ELASTICSEARCH_PORT", 9200),
+                username=config.get("ELASTICSEARCH_USERNAME", ""),
+                password=config.get("ELASTICSEARCH_PASSWORD", ""),
             ),
             attributes=[],
         )
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index aa2bb01842..d7a14207e9 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -42,18 +42,18 @@ class LindormVectorStoreConfig(BaseModel):
         return values

     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {"hosts": self.hosts}
+        params: dict[str, Any] = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params


 class LindormVectorStore(BaseVector):
-    def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
+    def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
         self._routing = None
         self._routing_field = None
-        if config.using_ugc:
-            routing_value: str = kwargs.get("routing_value")
+        if using_ugc:
+            routing_value: str | None = kwargs.get("routing_value")
             if routing_value is None:
                 raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
             self._routing = routing_value.lower()
@@ -64,7 +64,7 @@ class LindormVectorStore(BaseVector):
         super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
-        self._using_ugc = config.using_ugc
+        self._using_ugc = using_ugc
         self.kwargs = kwargs

     def get_type(self) -> str:
@@ -87,14 +87,15 @@ class LindormVectorStore(BaseVector):
                     "_id": uuids[i],
                 }
             }
-            action_values = {
+            action_values: dict[str, Any] = {
                 Field.CONTENT_KEY.value: documents[i].page_content,
                 Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
                 Field.METADATA_KEY.value: documents[i].metadata,
             }
             if self._using_ugc:
                 action_header["index"]["routing"] = self._routing
-                action_values[self._routing_field] = self._routing
+                if self._routing_field is not None:
+                    action_values[self._routing_field] = self._routing
             actions.append(action_header)
             actions.append(action_values)
         response = self._client.bulk(actions)
@@ -105,7 +106,9 @@ class LindormVectorStore(BaseVector):
         self.refresh()
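For context on the UGC routing logic above and in the query builders below: a compact sketch of how a routing term piggybacks on a metadata term query. The field names here (metadata.<key>.keyword, the routing field name) are illustrative, not the exact constants from this file:

from typing import Any, Optional


def build_metadata_query(
    key: str,
    value: str,
    routing: Optional[str] = None,
    routing_field: str = "routing_field",
) -> dict[str, Any]:
    # Base term query on the metadata keyword field; the UGC case adds the
    # routing value as one more mandatory term in the same bool/must clause.
    must: list[dict[str, Any]] = [{"term": {f"metadata.{key}.keyword": value}}]
    if routing is not None:
        must.append({"term": {f"{routing_field}.keyword": routing}})
    return {"query": {"bool": {"must": must}}}


print(build_metadata_query("doc_id", "abc-123", routing="tenant_a"))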
     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        query: dict[str, Any] = {
+            "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
+        }
         if self._using_ugc:
             query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
@@ -191,7 +194,8 @@ class LindormVectorStore(BaseVector):
         for doc, score in docs_and_scores:
             score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
             if score > score_threshold:
-                doc.metadata["score"] = score
+                if doc.metadata is not None:
+                    doc.metadata["score"] = score
             docs.append(doc)

         return docs

@@ -366,6 +370,7 @@ def default_text_search_query(
     routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
+    query_clause: dict[str, Any] = {}
     if routing is not None:
         query_clause = {
             "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
@@ -386,7 +391,7 @@ def default_text_search_query(
     else:
         must = [query_clause]

-    boolean_query = {"must": must}
+    boolean_query: dict[str, Any] = {"must": must}

     if must_not:
         if not isinstance(must_not, list):
@@ -426,7 +431,7 @@ def default_vector_search_query(
     filter_type = "post_filter" if filter_type is None else filter_type
     if not isinstance(filters, list):
         raise RuntimeError(f"unexpected filter with {type(filters)}")
-    final_ext = {"lvector": {}}
+    final_ext: dict[str, Any] = {"lvector": {}}
     if min_score != "0.0":
         final_ext["lvector"]["min_score"] = min_score
     if ef_search:
@@ -438,7 +443,7 @@ def default_vector_search_query(
     if client_refactor:
         final_ext["lvector"]["client_refactor"] = client_refactor

-    search_query = {
+    search_query: dict[str, Any] = {
         "size": k,
         "_source": True,  # force return '_source'
         "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
@@ -446,8 +451,8 @@ def default_vector_search_query(

     if filters is not None:
         # when using filter, transform filter from List[Dict] to Dict as valid format
-        filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
-        search_query["query"]["knn"][vector_field]["filter"] = filters  # filter should be Dict
+        filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
+        search_query["query"]["knn"][vector_field]["filter"] = filter_dict  # filter should be Dict
         if filter_type:
             final_ext["lvector"]["filter_type"] = filter_type

@@ -459,20 +464,26 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
         lindorm_config = LindormVectorStoreConfig(
-            hosts=dify_config.LINDORM_URL,
+            hosts=dify_config.LINDORM_URL or "",
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
             using_ugc=dify_config.USING_UGC_INDEX,
         )
         using_ugc = dify_config.USING_UGC_INDEX
+        if using_ugc is None:
+            raise ValueError("USING_UGC_INDEX is not set")
         routing_value = None
         if dataset.index_struct:
-            if using_ugc:
+            # if an existing record's index_struct_dict doesn't contain the using_ugc
+            # field, it is actually stored in the normal index format
+            stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
+            using_ugc = stored_in_ugc
+            if stored_in_ugc:
                 dimension = dataset.index_struct_dict["dimension"]
                 index_type = dataset.index_struct_dict["index_type"]
                 distance_type = dataset.index_struct_dict["distance_type"]
-                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
                 routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
             else:
                 index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
         else:
@@ -487,6 +498,7 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
                 "index_type": index_type,
                 "dimension": dimension,
                 "distance_type": distance_type,
+                "using_ugc": using_ugc,
             }
             dataset.index_struct = json.dumps(index_struct_dict)
             if using_ugc:
@@ -494,4 +506,4 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
                 routing_value = class_prefix
             else:
                 index_name = class_prefix
-        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index 5a263d6e78..9b029ffc19 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -3,8 +3,8 @@ import logging
 from typing import Any, Optional

 from pydantic import BaseModel, model_validator
-from pymilvus import MilvusClient, MilvusException
-from pymilvus.milvus_client import IndexParams
+from pymilvus import MilvusClient, MilvusException  # type: ignore
+from pymilvus.milvus_client import IndexParams  # type: ignore

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -54,14 +54,14 @@ class MilvusVector(BaseVector):
         self._client_config = config
         self._client = self._init_client(config)
         self._consistency_level = "Session"
-        self._fields = []
+        self._fields: list[str] = []

     def get_type(self) -> str:
         return VectorType.MILVUS

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
-        metadatas = [d.metadata for d in texts]
+        metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
         self.create_collection(embeddings, metadatas, index_params)
         self.add_texts(texts, embeddings)

@@ -161,8 +161,8 @@ class MilvusVector(BaseVector):
             return
         # Grab the existing collection if it exists
         if not self._client.has_collection(self._collection_name):
-            from pymilvus import CollectionSchema, DataType, FieldSchema
-            from pymilvus.orm.types import infer_dtype_bydata
+            from pymilvus import CollectionSchema, DataType, FieldSchema  # type: ignore
+            from pymilvus.orm.types import infer_dtype_bydata  # type: ignore

             # Determine embedding dim
             dim = len(embeddings[0])
@@ -217,10 +217,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
         return MilvusVector(
             collection_name=collection_name,
             config=MilvusConfig(
-                uri=dify_config.MILVUS_URI,
-                token=dify_config.MILVUS_TOKEN,
-                user=dify_config.MILVUS_USER,
-                password=dify_config.MILVUS_PASSWORD,
-                database=dify_config.MILVUS_DATABASE,
+                uri=dify_config.MILVUS_URI or "",
+                token=dify_config.MILVUS_TOKEN or "",
+                user=dify_config.MILVUS_USER or "",
+                password=dify_config.MILVUS_PASSWORD or "",
+                database=dify_config.MILVUS_DATABASE or "",
             ),
         )
diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
index b7b6b803ad..e63e1f522b 100644
--- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py
+++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py
@@ -74,15 +74,16 @@ class MyScaleVector(BaseVector):
         columns = ["id", "text", "vector", "metadata"]
         values = []
         for i, doc in enumerate(documents):
-            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
-            row = (
-                doc_id,
-                self.escape_str(doc.page_content),
-                embeddings[i],
-                json.dumps(doc.metadata) if doc.metadata else {},
-            )
-            values.append(str(row))
-            ids.append(doc_id)
+            if doc.metadata is not None:
+                doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+                row = (
+                    doc_id,
+                    self.escape_str(doc.page_content),
+                    embeddings[i],
+                    json.dumps(doc.metadata) if doc.metadata else {},
+                )
+                values.append(str(row))
+                ids.append(doc_id)
         sql = f"""
             INSERT INTO {self._config.database}.{self._collection_name}
             ({",".join(columns)}) VALUES {",".join(values)}
diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
index c44338d42a..957c799a60 100644
--- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
+++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
@@ -4,7 +4,7 @@ import math
 from typing import Any

 from pydantic import BaseModel, model_validator
-from pyobvector import VECTOR, ObVecClient
+from pyobvector import VECTOR, ObVecClient  # type: ignore
 from sqlalchemy import JSON, Column, String, func
 from sqlalchemy.dialects.mysql import LONGTEXT

@@ -131,7 +131,7 @@ class OceanBaseVector(BaseVector):

     def text_exists(self, id: str) -> bool:
         cur = self._client.get(table_name=self._collection_name, id=id)
-        return cur.rowcount != 0
+        return bool(cur.rowcount != 0)

     def delete_by_ids(self, ids: list[str]) -> None:
         self._client.delete(table_name=self._collection_name, ids=ids)
diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
index 7a976d7c3c..72a1502205 100644
--- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -66,7 +66,7 @@ class OpenSearchVector(BaseVector):
         return VectorType.OPENSEARCH

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        metadatas = [d.metadata for d in texts]
+        metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
         self.create_collection(embeddings, metadatas)
         self.add_texts(texts, embeddings)

@@ -244,7 +244,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
             dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))

         open_search_config = OpenSearchConfig(
-            host=dify_config.OPENSEARCH_HOST,
+            host=dify_config.OPENSEARCH_HOST or "localhost",
             port=dify_config.OPENSEARCH_PORT,
             user=dify_config.OPENSEARCH_USER,
             password=dify_config.OPENSEARCH_PASSWORD,
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 71c58c9d0c..dfff3563c3 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -5,11 +5,9 @@ import uuid
 from contextlib import contextmanager
 from typing import Any

-import jieba.posseg as pseg
-import nltk
+import jieba.posseg as pseg  # type: ignore
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator

 from configs import dify_config
@@ -90,12 +88,11 @@ class OracleVector(BaseVector):
     def numpy_converter_out(self, value):
         if value.typecode == "b":
-            dtype = numpy.int8
+
return numpy.array(value, copy=False, dtype=numpy.int8) elif value.typecode == "f": - dtype = numpy.float32 + return numpy.array(value, copy=False, dtype=numpy.float32) else: - dtype = numpy.float64 - return numpy.array(value, copy=False, dtype=dtype) + return numpy.array(value, copy=False, dtype=numpy.float64) def output_type_handler(self, cursor, metadata): if metadata.type_code is oracledb.DB_TYPE_VECTOR: @@ -137,17 +134,18 @@ class OracleVector(BaseVector): values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - # array.array("f", embeddings[i]), - numpy.array(embeddings[i]), + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + # array.array("f", embeddings[i]), + numpy.array(embeddings[i]), + ) ) - ) # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: cur.executemany( @@ -202,6 +200,10 @@ class OracleVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # lazy import + import nltk # type: ignore + from nltk.corpus import stopwords # type: ignore + top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later score_threshold = float(kwargs.get("score_threshold") or 0.0) @@ -283,10 +285,10 @@ class OracleVectorFactory(AbstractVectorFactory): return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=dify_config.ORACLE_HOST, + host=dify_config.ORACLE_HOST or "localhost", port=dify_config.ORACLE_PORT, - user=dify_config.ORACLE_USER, - password=dify_config.ORACLE_PASSWORD, - database=dify_config.ORACLE_DATABASE, + user=dify_config.ORACLE_USER or "system", + password=dify_config.ORACLE_PASSWORD or "oracle", + database=dify_config.ORACLE_DATABASE or "orcl", ), ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 7cbbdcc81f..221bc68d68 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -4,7 +4,7 @@ from typing import Any from uuid import UUID, uuid4 from numpy import ndarray -from pgvecto_rs.sqlalchemy import VECTOR +from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator from sqlalchemy import Float, String, create_engine, insert, select, text from sqlalchemy import text as sql_text @@ -58,7 +58,7 @@ class PGVectoRS(BaseVector): with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.commit() - self._fields = [] + self._fields: list[str] = [] class _Table(CollectionORM): __tablename__ = collection_name @@ -222,11 +222,11 @@ class PGVectoRSFactory(AbstractVectorFactory): return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=dify_config.PGVECTO_RS_HOST, - port=dify_config.PGVECTO_RS_PORT, - user=dify_config.PGVECTO_RS_USER, - password=dify_config.PGVECTO_RS_PASSWORD, - database=dify_config.PGVECTO_RS_DATABASE, + host=dify_config.PGVECTO_RS_HOST or "localhost", + port=dify_config.PGVECTO_RS_PORT or 5432, + user=dify_config.PGVECTO_RS_USER or "postgres", + password=dify_config.PGVECTO_RS_PASSWORD or "", + 
database=dify_config.PGVECTO_RS_DATABASE or "postgres", ), dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 40a9cdd136..271281ca7e 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -3,8 +3,8 @@ import uuid from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -98,16 +98,17 @@ class PGVector(BaseVector): values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - embeddings[i], + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_values( cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values ) @@ -216,11 +217,11 @@ class PGVectorFactory(AbstractVectorFactory): return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=dify_config.PGVECTOR_HOST, + host=dify_config.PGVECTOR_HOST or "localhost", port=dify_config.PGVECTOR_PORT, - user=dify_config.PGVECTOR_USER, - password=dify_config.PGVECTOR_PASSWORD, - database=dify_config.PGVECTOR_DATABASE, + user=dify_config.PGVECTOR_USER or "postgres", + password=dify_config.PGVECTOR_PASSWORD or "", + database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 3811458e02..6e94cb69db 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -51,6 +51,8 @@ class QdrantConfig(BaseModel): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): + if not self.root_path: + raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) return {"path": path} @@ -149,9 +151,12 @@ class QdrantVector(BaseVector): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] - added_ids = [] - for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + added_ids = [] + filtered_metadatas = [m for m in metadatas if m is not None] # filter out None values to match the expected type + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, filtered_metadatas, uuids, 64, self._group_id + ): self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) @@ -194,7 +199,7 @@ class QdrantVector(BaseVector): batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", # Ensure group_id is never None Field.GROUP_KEY.value, ), ) @@ -337,18 +342,20 @@ class QdrantVector(BaseVector): ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate 
check score threshold score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=dify_config.QDRANT_URL, + endpoint=dify_config.QDRANT_URL or "", api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.config.root_path, + root_path=str(current_app.config.root_path), timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index f373dcfeab..a3a20448ff 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -3,7 +3,7 @@ import uuid from typing import Any, Optional from pydantic import BaseModel, model_validator -from sqlalchemy import Column, Sequence, String, Table, create_engine, insert +from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session @@ -58,14 +58,14 @@ class RelytVector(BaseVector): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self.client = create_engine(self._url) - self._fields = [] + self._fields: list[str] = [] self._group_id = group_id def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = {} + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + index_params: dict[str, Any] = {} metadatas = [d.metadata for d in texts] self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -107,10 +107,10 @@ class RelytVector(BaseVector): redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from pgvecto_rs.sqlalchemy import VECTOR + from pgvecto_rs.sqlalchemy import VECTOR # type: ignore ids = [str(uuid.uuid1()) for _ in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] for metadata in metadatas: metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] @@ -242,10 +242,6 @@ class RelytVector(BaseVector): filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError("Could not import Row from sqlalchemy.engine. 
Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: @@ -275,7 +271,7 @@ class RelytVector(BaseVector): # Execute the query and fetch the results with self.client.connect() as conn: - results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall() + results = conn.execute(sql_text(sql_query), params).fetchall() documents_with_scores = [ ( @@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory): return RelytVector( collection_name=collection_name, config=RelytConfig( - host=dify_config.RELYT_HOST, + host=dify_config.RELYT_HOST or "localhost", port=dify_config.RELYT_PORT, - user=dify_config.RELYT_USER, - password=dify_config.RELYT_PASSWORD, - database=dify_config.RELYT_DATABASE, + user=dify_config.RELYT_USER or "", + password=dify_config.RELYT_PASSWORD or "", + database=dify_config.RELYT_DATABASE or "default", ), group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index f971a9c5eb..c15f4b229f 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -2,10 +2,10 @@ import json from typing import Any, Optional from pydantic import BaseModel -from tcvectordb import VectorDBClient -from tcvectordb.model import document, enum -from tcvectordb.model import index as vdb_index -from tcvectordb.model.document import Filter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model import document, enum # type: ignore +from tcvectordb.model import index as vdb_index # type: ignore +from tcvectordb.model.document import Filter # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -25,8 +25,8 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = (1,) - replicas: int = (2,) + shard: int = 1 + replicas: int = 2 def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -120,15 +120,15 @@ class TencentVector(BaseVector): metadatas = [doc.metadata for doc in documents] total_count = len(embeddings) docs = [] - for id in range(0, total_count): + for i in range(0, total_count): if metadatas is None: continue - metadata = json.dumps(metadatas[id]) + metadata = metadatas[i] or {} doc = document.Document( - id=metadatas[id]["doc_id"], - vector=embeddings[id], - text=texts[id], - metadata=metadata, + id=metadata.get("doc_id"), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadata), ) docs.append(doc) self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout) @@ -159,8 +159,8 @@ class TencentVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def _get_search_res(self, res, score_threshold): - docs = [] + def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: + docs: list[Document] = [] if res is None or len(res) == 0: return docs @@ -193,7 +193,7 @@ class TencentVectorFactory(AbstractVectorFactory): return TencentVector( collection_name=collection_name, config=TencentConfig( - url=dify_config.TENCENT_VECTOR_DB_URL, + url=dify_config.TENCENT_VECTOR_DB_URL or "", api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, username=dify_config.TENCENT_VECTOR_DB_USERNAME, diff --git 
a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index cfd47aac5b..19c5579a68 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -54,7 +54,10 @@ class TidbOnQdrantConfig(BaseModel): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): - path = os.path.join(self.root_path, path) + if self.root_path: + path = os.path.join(self.root_path, path) + else: + raise ValueError("root_path is required") return {"path": path} else: @@ -157,7 +160,7 @@ class TidbOnQdrantVector(BaseVector): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] added_ids = [] for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): @@ -203,7 +206,7 @@ class TidbOnQdrantVector(BaseVector): batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", Field.GROUP_KEY.value, ), ) @@ -334,18 +337,20 @@ class TidbOnQdrantVector(BaseVector): ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -427,12 +432,12 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): else: new_cluster = TidbService.create_tidb_serverless_cluster( - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + dify_config.TIDB_PROJECT_ID or "", + dify_config.TIDB_API_URL or "", + dify_config.TIDB_IAM_API_URL or "", + dify_config.TIDB_PUBLIC_KEY or "", + dify_config.TIDB_PRIVATE_KEY or "", + dify_config.TIDB_REGION or "", ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], @@ -464,9 +469,9 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL, + endpoint=dify_config.TIDB_ON_QDRANT_URL or "", api_key=TIDB_ON_QDRANT_API_KEY, - root_path=config.root_path, + root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 8dd5922ad0..0a48c79511 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ 
b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -146,7 +146,7 @@ class TidbService: iam_url: str, public_key: str, private_key: str, - ) -> list[dict]: + ): """ Update the status of a new TiDB Serverless cluster. :param project_id: The project ID of the TiDB Cloud project (required). :param cluster_id: The ID of the cluster to be updated (required). :param ..." wait
text in texts if text.metadata and "doc_id" in text.metadata] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 6d2e04fc02..523fa80f12 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -193,10 +193,13 @@ class Vector: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if doc_id: + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 4f927f2899..9de8761a91 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -2,7 +2,7 @@ import json from typing import Any from pydantic import BaseModel -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Data, DistanceType, Field, @@ -121,11 +121,12 @@ class VikingDBVector(BaseVector): for i, page_content in enumerate(page_contents): metadata = {} if metadatas is not None: - for key, val in metadatas[i].items(): + for key, val in (metadatas[i] or {}).items(): metadata[key] = val + # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, vdb_Field.CONTENT_KEY.value: page_content, vdb_Field.METADATA_KEY.value: json.dumps(metadata), @@ -178,7 +179,7 @@ class VikingDBVector(BaseVector): score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) - def _get_search_res(self, results, score_threshold): + def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: return [] @@ -191,7 +192,7 @@ class VikingDBVector(BaseVector): metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 649cfbfea8..68d043a19f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -3,7 +3,7 @@ import json from typing import Any, Optional import requests -import weaviate +import weaviate # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -107,7 +107,8 @@ class WeaviateVector(BaseVector): for i, text in enumerate(texts): data_properties = {Field.TEXT_KEY.value: text} if metadatas is not None: - for key, val in metadatas[i].items(): + # metadata maybe None + for key, val in (metadatas[i] or {}).items(): data_properties[key] = self._json_serializable(val) batch.add_data_object( @@ -208,10 +209,11 @@ class WeaviateVector(BaseVector): 
score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata["score"] = score - docs.append(doc) + if doc.metadata is not None: + doc.metadata["score"] = score + docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -275,7 +277,7 @@ class WeaviateVectorFactory(AbstractVectorFactory): return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=dify_config.WEAVIATE_ENDPOINT, + endpoint=dify_config.WEAVIATE_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 306f0c27ea..6d16a9bdc2 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -89,6 +89,9 @@ class DatasetDocumentStore: if not isinstance(doc, Document): raise ValueError("doc must be a Document") + if doc.metadata is None: + raise ValueError("doc.metadata must be a dict") + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it @@ -179,10 +182,10 @@ class DatasetDocumentStore: if document_segment is None: return None + data: Optional[str] = document_segment.index_node_hash + return data - return document_segment.index_node_hash - - def get_document_segment(self, doc_id: str) -> DocumentSegment: + def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index fc8e0440c3..a2c8737da7 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Optional, cast +from typing import Any, Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ class CacheEmbedding(Embeddings): def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] + text_embeddings: list[Any] = [None for _ in range(len(texts))] embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) @@ -64,7 +64,13 @@ class CacheEmbedding(Embeddings): for vector in embedding_result.embeddings: try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore + # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan + if np.isnan(normalized_embedding).any(): + # for issue #11827 float values are not json compliant + logger.warning(f"Normalized embedding is nan: {normalized_embedding}") + continue embedding_queue_embeddings.append(normalized_embedding) except IntegrityError: db.session.rollback() @@ -72,8 +78,8 @@ class CacheEmbedding(Embeddings): 
logging.exception("Failed transform embedding") cache_embeddings = [] try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = n_embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( @@ -81,7 +87,7 @@ class CacheEmbedding(Embeddings): hash=hash, provider_name=self._model_instance.provider, ) - embedding_cache.set_embedding(embedding) + embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() @@ -110,7 +116,10 @@ class CacheEmbedding(Embeddings): ) embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore + if np.isnan(embedding_results).any(): + raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 3692b5d19d..7c00c668dd 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -14,7 +14,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Document = None + document: Optional[Document] = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index fc33165719..c444105bb5 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional +from typing import Optional, cast import pandas as pd from openpyxl import load_workbook @@ -47,7 +47,7 @@ class ExcelExtractor(BaseExtractor): for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): cell = sheet.cell( - row=index + 2, column=col_index + 1 + row=cast(int, index) + 2, column=col_index + 1 ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" @@ -60,8 +60,8 @@ class ExcelExtractor(BaseExtractor): elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") - for sheet_name in excel_file.sheet_names: - df = excel_file.parse(sheet_name=sheet_name) + for excel_sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) for _, row in df.iterrows(): diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 69659e3108..a473b3dfa7 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -10,6 +10,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.extractor_base import BaseExtractor from 
core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor @@ -66,9 +67,13 @@ class ExtractProcessor: filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = "." + re.search(r"\.(\w+)$", filename).group(1) - - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + match = re.search(r"\.(\w+)$", filename) + if match: + suffix = "." + match.group(1) + else: + suffix = "" + # FIXME mypy: cannot determine type of 'tempfile._get_candidate_names'; better not to use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: @@ -89,15 +94,20 @@ class ExtractProcessor: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: + assert extract_setting.upload_file is not None, "upload_file is required" upload_file: UploadFile = extract_setting.upload_file suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + # FIXME mypy: cannot determine type of 'tempfile._get_candidate_names'; better not to use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY + assert unstructured_api_url is not None, "unstructured_api_url is required" + assert unstructured_api_key is not None, "unstructured_api_key is required" + extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) @@ -156,6 +166,7 @@ class ExtractProcessor: extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.NOTION.value: + assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, @@ -165,6 +176,7 @@ ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 17c2087a0a..8ae4579c7c 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,5 +1,6 @@ import json import time +from typing import cast import requests @@ -20,9 +21,9 @@ class FirecrawlApp: json_data.update(params) response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: - response = response.json() - if response["success"] == True: - data = 
response["data"] + response_data = response.json() + if response_data["success"] == True: + data = response_data["data"] return { "title": data.get("metadata").get("title"), "description": data.get("metadata").get("description"), @@ -30,7 +31,7 @@ class FirecrawlApp: "markdown": data.get("markdown"), } else: - raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") @@ -46,9 +47,11 @@ class FirecrawlApp: response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: job_id = response.json().get("jobId") - return job_id + return cast(str, job_id) else: self._handle_error(response, "start crawl job") + # FIXME: unreachable code for mypy + return "" # unreachable def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() @@ -64,9 +67,9 @@ class FirecrawlApp: for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - "title": item.get("metadata").get("title"), - "description": item.get("metadata").get("description"), - "source_url": item.get("metadata").get("sourceURL"), + "title": item.get("metadata", {}).get("title"), + "description": item.get("metadata", {}).get("description"), + "source_url": item.get("metadata", {}).get("sourceURL"), "markdown": item.get("markdown"), } url_data_list.append(url_data) @@ -92,6 +95,8 @@ class FirecrawlApp: else: self._handle_error(response, "check crawl status") + # FIXME: unreachable code for mypy + return {} # unreachable def _prepare_headers(self): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 560c2d1d84..350b522347 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -23,6 +23,7 @@ class HtmlExtractor(BaseExtractor): return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: + text: str = "" with open(self._file_path, "rb") as fp: soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 87a4ce08bf..fdc2e46d14 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, cast import requests @@ -78,6 +78,7 @@ class NotionExtractor(BaseExtractor): def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" + assert self._notion_access_token is not None, "Notion access token is required" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ @@ -96,6 +97,7 @@ class NotionExtractor(BaseExtractor): for result in data["results"]: properties = result["properties"] data = {} + value: Any for property_name, property_value in properties.items(): type = property_value["type"] if type == 
"multi_select": @@ -130,6 +132,7 @@ class NotionExtractor(BaseExtractor): return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) @@ -184,6 +187,7 @@ class NotionExtractor(BaseExtractor): def _read_block(self, block_id: str, num_tabs: int = 0) -> str: """Read a block.""" + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) @@ -242,6 +246,7 @@ class NotionExtractor(BaseExtractor): def _read_table_rows(self, block_id: str) -> str: """Read table rows.""" + assert self._notion_access_token is not None, "Notion access token is required" done = False result_lines_arr = [] start_cursor = None @@ -296,7 +301,7 @@ class NotionExtractor(BaseExtractor): result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: DocumentModel): + def update_last_edited_time(self, document_model: Optional[DocumentModel]): if not document_model: return @@ -309,6 +314,7 @@ class NotionExtractor(BaseExtractor): db.session.commit() def get_notion_last_edited_time(self) -> str: + assert self._notion_access_token is not None, "Notion access token is required" obj_id = self._notion_obj_id page_type = self._notion_page_type if page_type == "database": @@ -330,7 +336,7 @@ class NotionExtractor(BaseExtractor): ) data = res.json() - return data["last_edited_time"] + return cast(str, data["last_edited_time"]) @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: @@ -349,4 +355,4 @@ class NotionExtractor(BaseExtractor): f"and notion workspace {notion_workspace_id}" ) - return data_source_binding.access_token + return cast(str, data_source_binding.access_token) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 57cb9610ba..89a7061c26 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from collections.abc import Iterator -from typing import Optional +from typing import Optional, cast from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor): plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode("utf-8") + text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -53,7 +53,7 @@ class PdfExtractor(BaseExtractor): def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 + import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index bd669bbad3..9647dedfff 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,7 @@ import base64 import logging -from bs4 import BeautifulSoup +from 
bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 35220b558a..80c29157aa 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -30,6 +30,9 @@ class UnstructuredEpubExtractor(BaseExtractor): if self._api_url: from unstructured.partition.api import partition_via_api + if self._api_key is None: + raise ValueError("api_key is required") + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 0fdcd58b2e..e504d4bc23 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -27,9 +27,11 @@ class UnstructuredPPTExtractor(BaseExtractor): elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: raise NotImplementedError("Unstructured API Url is not configured") - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number + if page is None: + continue text = element.text if page in text_by_page: text_by_page[page] += "\n" + text diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index ab41290fbc..cefe72b290 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -29,14 +29,15 @@ class UnstructuredPPTXExtractor(BaseExtractor): from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path) - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number text = element.text - if page in text_by_page: - text_by_page[page] += "\n" + text - else: - text_by_page[page] = text + if page is not None: + if page in text_by_page: + text_by_page[page] += "\n" + text + else: + text_by_page[page] = text combined_texts = list(text_by_page.values()) documents = [] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0c38a9c076..c3161bc812 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -89,6 +89,8 @@ class WordExtractor(BaseExtractor): response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + if image_ext is None: + continue file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) @@ -97,6 +99,8 @@ class WordExtractor(BaseExtractor): continue else: image_ext = rel.target_ref.split(".")[-1] + if image_ext is None: + continue # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext @@ -226,6 +230,8 @@ class WordExtractor(BaseExtractor): if x_child is None: continue if x.tag.endswith("instrText"): + if x.text is None: + continue for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index be857bd122..7e5efdc66e 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -49,6 +49,7 @@ class BaseIndexProcessor(ABC): """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule["mode"] == "custom": # The user-defined segmentation rule rules = processing_rule["rules"] diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index 9b855ece2c..c5ba6295f3 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -9,7 +9,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess class IndexProcessorFactory: """IndexProcessorInit.""" - def __init__(self, index_type: str): + def __init__(self, index_type: str | None): self._index_type = index_type def init_index_processor(self) -> BaseIndexProcessor: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a631f953ce..c66fa54d50 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -27,12 +27,13 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. 
splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule", {}), + embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {})) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) @@ -41,8 +42,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 320f0157a1..20fd16e8f3 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -32,15 +32,16 @@ class QAIndexProcessor(BaseIndexProcessor): def transform(self, documents: list[Document], **kwargs) -> list[Document]: splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule") or {}, + embedding_model_instance=kwargs.get("embedding_model_instance"), ) # Split the text documents into nodes. 
-        all_documents = []
-        all_qa_documents = []
+        all_documents: list[Document] = []
+        all_qa_documents: list[Document] = []
         for document in documents:
             # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {})
             document.page_content = document_text

             # parse document to nodes
@@ -50,8 +51,9 @@ class QAIndexProcessor(BaseIndexProcessor):
                 if document_node.page_content.strip():
                     doc_id = str(uuid.uuid4())
                     hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata["doc_id"] = doc_id
-                    document_node.metadata["doc_hash"] = hash
+                    if document_node.metadata is not None:
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
                     document_node.page_content = remove_leading_symbols(page_content)
@@ -64,7 +66,7 @@ class QAIndexProcessor(BaseIndexProcessor):
             document_format_thread = threading.Thread(
                 target=self._format_qa_document,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "tenant_id": kwargs.get("tenant_id"),
                     "document_node": doc,
                     "all_qa_documents": all_qa_documents,
@@ -148,11 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
             qa_documents = []
             for result in document_qa_list:
                 qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
-                doc_id = str(uuid.uuid4())
-                hash = helper.generate_text_hash(result["question"])
-                qa_document.metadata["answer"] = result["answer"]
-                qa_document.metadata["doc_id"] = doc_id
-                qa_document.metadata["doc_hash"] = hash
+                if qa_document.metadata is not None:
+                    doc_id = str(uuid.uuid4())
+                    hash = helper.generate_text_hash(result["question"])
+                    qa_document.metadata["answer"] = result["answer"]
+                    qa_document.metadata["doc_id"] = doc_id
+                    qa_document.metadata["doc_hash"] = hash
                 qa_documents.append(qa_document)
             format_documents.extend(qa_documents)
         except Exception as e:
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py
index 6ae432a526..ac7a3f8bb8 100644
--- a/api/core/rag/rerank/rerank_model.py
+++ b/api/core/rag/rerank/rerank_model.py
@@ -30,7 +30,11 @@ class RerankModelRunner(BaseRerankRunner):
         doc_ids = set()
         unique_documents = []
         for document in documents:
-            if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
+            if (
+                document.provider == "dify"
+                and document.metadata is not None
+                and document.metadata["doc_id"] not in doc_ids
+            ):
                 doc_ids.add(document.metadata["doc_id"])
                 docs.append(document.page_content)
                 unique_documents.append(document)
@@ -54,7 +58,8 @@ class RerankModelRunner(BaseRerankRunner):
                 metadata=documents[result.index].metadata,
                 provider=documents[result.index].provider,
             )
-            rerank_document.metadata["score"] = result.score
-            rerank_documents.append(rerank_document)
+            if rerank_document.metadata is not None:
+                rerank_document.metadata["score"] = result.score
+                rerank_documents.append(rerank_document)

         return rerank_documents
diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py
index 4719be012f..cbc96037bf 100644
--- a/api/core/rag/rerank/weight_rerank.py
+++ b/api/core/rag/rerank/weight_rerank.py
@@ -39,7 +39,7 @@ class WeightRerankRunner(BaseRerankRunner):
         unique_documents = []
         doc_ids = set()
         for document in documents:
-            if document.metadata["doc_id"] not in doc_ids:
+            if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
                 doc_ids.add(document.metadata["doc_id"])
                 unique_documents.append(document)

@@ -56,10 +56,11 @@ class WeightRerankRunner(BaseRerankRunner):
             )
             if score_threshold and score < score_threshold:
                 continue
-            document.metadata["score"] = score
-            rerank_documents.append(document)
+            if document.metadata is not None:
+                document.metadata["score"] = score
+                rerank_documents.append(document)

-        rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
+        rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
         return rerank_documents[:top_n] if top_n else rerank_documents

     def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
@@ -76,8 +77,9 @@ class WeightRerankRunner(BaseRerankRunner):
         for document in documents:
             # get the document keywords
             document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
-            document.metadata["keywords"] = document_keywords
-            documents_keywords.append(document_keywords)
+            if document.metadata is not None:
+                document.metadata["keywords"] = document_keywords
+                documents_keywords.append(document_keywords)

         # Counter query keywords(TF)
         query_keyword_counts = Counter(query_keywords)
@@ -162,7 +164,7 @@ class WeightRerankRunner(BaseRerankRunner):
         query_vector = cache_embedding.embed_query(query)
         for document in documents:
             # calculate cosine similarity
-            if "score" in document.metadata:
+            if document.metadata and "score" in document.metadata:
                 query_vector_scores.append(document.metadata["score"])
             else:
                 # transform to NumPy
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 04c9244263..8a7172f27c 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -1,7 +1,7 @@
 import math
 import threading
 from collections import Counter
-from typing import Optional, cast
+from typing import Any, Optional, cast

 from flask import Flask, current_app

@@ -32,7 +32,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from services.external_knowledge_service import ExternalDatasetService

-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -138,12 +138,12 @@ class DatasetRetrieval:
                 user_from,
                 available_datasets,
                 query,
-                retrieve_config.top_k,
-                retrieve_config.score_threshold,
-                retrieve_config.rerank_mode,
+                retrieve_config.top_k or 0,
+                retrieve_config.score_threshold or 0,
+                retrieve_config.rerank_mode or "reranking_model",
                 retrieve_config.reranking_model,
                 retrieve_config.weights,
-                retrieve_config.reranking_enabled,
+                retrieve_config.reranking_enabled if retrieve_config.reranking_enabled is not None else True,
                 message_id,
             )

@@ -298,10 +298,11 @@ class DatasetRetrieval:
                     metadata=external_document.get("metadata"),
                     provider="external",
                 )
-                document.metadata["score"] = external_document.get("score")
-                document.metadata["title"] = external_document.get("title")
-                document.metadata["dataset_id"] = dataset_id
-                document.metadata["dataset_name"] = dataset.name
+                if document.metadata is not None:
+                    document.metadata["score"] = external_document.get("score")
+                    document.metadata["title"] = external_document.get("title")
+                    document.metadata["dataset_id"] = dataset_id
+                    document.metadata["dataset_name"] = dataset.name
                 results.append(document)
         else:
             retrieval_model_config = dataset.retrieval_model or default_retrieval_model
@@ -323,7 +324,7 @@ class DatasetRetrieval:
                 score_threshold = 0.0
                 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                 if score_threshold_enabled:
-                    score_threshold = retrieval_model_config.get("score_threshold")
+                    score_threshold = retrieval_model_config.get("score_threshold", 0.0)

                 with measure_time() as timer:
                     results = RetrievalService.retrieve(
@@ -356,14 +357,14 @@ class DatasetRetrieval:
         score_threshold: float,
         reranking_mode: str,
         reranking_model: Optional[dict] = None,
-        weights: Optional[dict] = None,
+        weights: Optional[dict[str, Any]] = None,
         reranking_enable: bool = True,
         message_id: Optional[str] = None,
     ):
         if not available_datasets:
             return []
         threads = []
-        all_documents = []
+        all_documents: list[Document] = []
         dataset_ids = [dataset.id for dataset in available_datasets]
         index_type_check = all(
             item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
@@ -390,15 +391,18 @@ class DatasetRetrieval:
                     "The configured knowledge base list have different embedding model, please set reranking model."
                 )
             if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
-                weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
-                weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
+                if weights is not None:
+                    weights["vector_setting"]["embedding_provider_name"] = available_datasets[
+                        0
+                    ].embedding_model_provider
+                    weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model

         for dataset in available_datasets:
             index_type = dataset.indexing_technique
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset.id,
                     "query": query,
                     "top_k": top_k,
@@ -437,18 +441,19 @@ class DatasetRetrieval:
         """Handle retrieval end."""
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
-            query = db.session.query(DocumentSegment).filter(
-                DocumentSegment.index_node_id == document.metadata["doc_id"]
-            )
+            if document.metadata is not None:
+                query = db.session.query(DocumentSegment).filter(
+                    DocumentSegment.index_node_id == document.metadata["doc_id"]
+                )

-            # if 'dataset_id' in document.metadata:
-            if "dataset_id" in document.metadata:
-                query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+                # if 'dataset_id' in document.metadata:
+                if "dataset_id" in document.metadata:
+                    query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])

-            # add hit count to document segment
-            query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
+                # add hit count to document segment
+                query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

-            db.session.commit()
+                db.session.commit()

         # get tracing instance
         trace_manager: TraceQueueManager | None = (
@@ -502,10 +507,11 @@ class DatasetRetrieval:
                     metadata=external_document.get("metadata"),
                     provider="external",
                 )
-                document.metadata["score"] = external_document.get("score")
-                document.metadata["title"] = external_document.get("title")
-                document.metadata["dataset_id"] = dataset_id
-                document.metadata["dataset_name"] = dataset.name
+                if document.metadata is not None:
+                    document.metadata["score"] = external_document.get("score")
+                    document.metadata["title"] = external_document.get("title")
+                    document.metadata["dataset_id"] = dataset_id
+                    document.metadata["dataset_name"] = dataset.name
                 all_documents.append(document)
         else:
             # get retrieval model , if the model is not setting , using default
@@ -637,10 +643,11 @@ class DatasetRetrieval:
         query_keywords = keyword_table_handler.extract_keywords(query, None)
         documents_keywords = []
         for document in documents:
-            # get the document keywords
-            document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
-            document.metadata["keywords"] = document_keywords
-            documents_keywords.append(document_keywords)
+            if document.metadata is not None:
+                # get the document keywords
+                document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+                document.metadata["keywords"] = document_keywords
+                documents_keywords.append(document_keywords)

         # Counter query keywords(TF)
         query_keyword_counts = Counter(query_keywords)
@@ -698,8 +705,9 @@ class DatasetRetrieval:

         for document, score in zip(documents, similarities):
             # format document
-            document.metadata["score"] = score
-        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
+            if document.metadata is not None:
+                document.metadata["score"] = score
+        documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
         return documents[:top_k] if top_k else documents

     def calculate_vector_score(
@@ -707,10 +715,12 @@ class DatasetRetrieval:
     ) -> list[Document]:
         filter_documents = []
         for document in all_documents:
-            if score_threshold is None or document.metadata["score"] >= score_threshold:
+            if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
                 filter_documents.append(document)

         if not filter_documents:
             return []
-        filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
+        filter_documents = sorted(
+            filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
+        )

         return filter_documents[:top_k] if top_k else filter_documents
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index 06147fe7b5..b008d0df9c 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -1,7 +1,8 @@
-from typing import Union
+from typing import Union, cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage

@@ -27,11 +28,14 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=dataset_tools,
-            stream=False,
-            model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+        result = cast(
+            LLMResult,
+            model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                tools=dataset_tools,
+                stream=False,
+                model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+            ),
         )
         if result.message.tool_calls:
             # get retrieval model config
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index 68fab0c127..05e8d043df 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -1,9 +1,9 @@
 from collections.abc import Generator, Sequence
-from typing import Union
+from typing import Union, cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -92,6 +92,7 @@ class ReactMultiDatasetRouter:
         suffix: str = SUFFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
     ) -> Union[str, None]:
+        prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
         if model_config.mode == "chat":
             prompt = self.create_chat_prompt(
                 query=query,
@@ -149,12 +150,15 @@ class ReactMultiDatasetRouter:
         :param stop: stop
         :return:
         """
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=completion_param,
-            stop=stop,
-            stream=True,
-            user=user_id,
+        invoke_result = cast(
+            Generator[LLMResult, None, None],
+            model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=completion_param,
+                stop=stop,
+                stream=True,
+                user=user_id,
+            ),
         )

         # handle invoke result
@@ -172,7 +176,7 @@ class ReactMultiDatasetRouter:
         :return:
         """
         model = None
-        prompt_messages = []
+        prompt_messages: list[PromptMessage] = []
         full_text = ""
         usage = None
         for result in invoke_result:
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index e0cd3e53f1..91fb033c49 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder(
         cls: type[TS],
         embedding_model_instance: Optional[ModelInstance],
-        allowed_special: Union[Literal[all], Set[str]] = set(),
-        disallowed_special: Union[Literal[all], Collection[str]] = "all",
+        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
         **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 89a9650b68..72c4700d5c 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -92,7 +92,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         texts, metadatas = [], []
         for doc in documents:
             texts.append(doc.page_content)
-            metadatas.append(doc.metadata)
+            metadatas.append(doc.metadata or {})
         return self.create_documents(texts, metadatas=metadatas)

     def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
         """Text splitter that uses HuggingFace tokenizer to count length."""
         try:
-            from transformers import PreTrainedTokenizerBase
+            from transformers import PreTrainedTokenizerBase  # type: ignore

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
                 raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
diff --git a/api/core/tools/builtin_tool/providers/_positions.py b/api/core/tools/builtin_tool/providers/_positions.py
index 224b695ff9..44a90db038 100644
--- a/api/core/tools/builtin_tool/providers/_positions.py
+++ b/api/core/tools/builtin_tool/providers/_positions.py
@@ -5,7 +5,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity


 class BuiltinToolProviderSort:
-    _position = {}
+    _position: dict[str, int] = {}

     @classmethod
     def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index ae5e089d02..80fff0c084 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -23,8 +23,10 @@ class TTSTool(BuiltinTool):
         provider, model = tool_parameters.get("model").split("#")  # type: ignore
         voice = tool_parameters.get(f"voice#{provider}#{model}")
         model_manager = ModelManager()
+        if not self.runtime:
+            raise ValueError("Runtime is required")
         model_instance = model_manager.get_model_instance(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             provider=provider,
             model_type=ModelType.TTS,
             model=model,
@@ -47,8 +49,11 @@ class TTSTool(BuiltinTool):
         )

     def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+        if not self.runtime:
+            raise ValueError("Runtime is required")
         model_provider_service = ModelProviderService()
-        models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
+        tid: str = self.runtime.tenant_id or ""
+        models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
         items = []
         for provider_model in models:
             provider = provider_model.provider
@@ -68,6 +73,8 @@ class TTSTool(BuiltinTool):
                 ToolParameter(
                     name=f"voice#{provider}#{model}",
                     label=I18nObject(en_US=f"Voice of {model}({provider})"),
+                    human_description=I18nObject(en_US=f"Select a voice for {model} model"),
+                    placeholder=I18nObject(en_US="Select a voice"),
                     type=ToolParameter.ToolParameterType.SELECT,
                     form=ToolParameter.ToolParameterForm.FORM,
                     options=[
@@ -89,6 +96,7 @@ class TTSTool(BuiltinTool):
                 type=ToolParameter.ToolParameterType.SELECT,
                 form=ToolParameter.ToolParameterForm.FORM,
                 required=True,
+                placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
                 options=options,
             ),
         )
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index a51813ba40..0dfe9a37c3 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -49,9 +49,12 @@ class BuiltinTool(Tool):
         :return: the model result
         """
         # invoke model
+        if self.runtime is None or self.identity is None:
+            raise ValueError("runtime and identity are required")
+
         return ModelInvocationUtils.invoke(
             user_id=user_id,
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             tool_type="builtin",
             tool_name=self.entity.identity.name,
             prompt_messages=prompt_messages,
@@ -67,8 +70,11 @@ class BuiltinTool(Tool):
         :param model_config: the model config
         :return: the max tokens
         """
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
         return ModelInvocationUtils.get_max_llm_context_tokens(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
         )

     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -78,7 +84,12 @@ class BuiltinTool(Tool):
         :param prompt_messages: the prompt messages
         :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+
+        return ModelInvocationUtils.calculate_tokens(
+            tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+        )

     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()
@@ -120,16 +131,16 @@ class BuiltinTool(Tool):
         # merge lines into messages with max tokens
         messages: list[str] = []
-        for i in new_lines:
+        for j in new_lines:
             if len(messages) == 0:
-                messages.append(i)
+                messages.append(j)
             else:
-                if len(messages[-1]) + len(i) < max_tokens * 0.5:
-                    messages[-1] += i
-                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
-                    messages.append(i)
+                if len(messages[-1]) + len(j) < max_tokens * 0.5:
+                    messages[-1] += j
+                if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
+                    messages.append(j)
                 else:
-                    messages[-1] += i
+                    messages[-1] += j

         summaries = []
         for i in range(len(messages)):
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index 457954c961..7efd7e00b8 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -130,7 +130,7 @@ class ApiToolProviderController(ToolProviderController):
             runtime=ToolRuntime(tenant_id=self.tenant_id),
         )

-    def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
+    def load_bundled_tools(self, tools: list[ApiToolBundle]):
         """
         load bundled tools

@@ -151,6 +151,8 @@ class ApiToolProviderController(ToolProviderController):
         """
         if len(self.tools) > 0:
             return self.tools
+        if self.identity is None:
+            return None

         tools: list[ApiTool] = []

@@ -170,7 +172,7 @@ class ApiToolProviderController(ToolProviderController):
         self.tools = tools
         return tools

-    def get_tool(self, tool_name: str) -> ApiTool:
+    def get_tool(self, tool_name: str):
        """
        get tool by name
diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py
index 80674de798..7fa562dc69 100644
--- a/api/core/tools/custom_tool/tool.py
+++ b/api/core/tools/custom_tool/tool.py
@@ -40,6 +40,8 @@ class ApiTool(Tool):
         :param meta: the meta data of a tool call processing, tenant_id is required
         :return: the new tool
         """
+        if self.api_bundle is None:
+            raise ValueError("api_bundle is required")
         return self.__class__(
             entity=self.entity,
             api_bundle=self.api_bundle.model_copy(),
@@ -67,10 +69,12 @@ class ApiTool(Tool):
         return ToolProviderType.API

     def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
-        if self.runtime == None:
+        if self.runtime is None:
             raise ToolProviderCredentialValidationError("runtime not initialized")

         headers = {}
+        if self.runtime is None:
+            raise ValueError("runtime is required")
         credentials = self.runtime.credentials or {}

         if "auth_type" not in credentials:
@@ -121,9 +125,9 @@ class ApiTool(Tool):
                 response = response.json()
                 try:
                     return json.dumps(response, ensure_ascii=False)
-                except Exception as e:
+                except Exception:
                     return json.dumps(response)
-            except Exception as e:
+            except Exception:
                 return response.text
         else:
             raise ValueError(f"Invalid response type {type(response)}")
@@ -147,7 +151,8 @@ class ApiTool(Tool):
         params = {}
         path_params = {}
-        body = {}
+        # FIXME: body should be a dict[str, Any] but it changed a lot in this function
+        body: Any = {}
         cookies = {}
         files = []
@@ -208,7 +213,7 @@ class ApiTool(Tool):
             body = body

         if method in {"get", "head", "post", "put", "delete", "patch"}:
-            response = getattr(ssrf_proxy, method)(
+            response: httpx.Response = getattr(ssrf_proxy, method)(
                 url,
                 params=params,
                 headers=headers,
@@ -291,7 +296,7 @@ class ApiTool(Tool):
                     raise ValueError(f"Invalid type {property['type']} for property {property}")
             elif "anyOf" in property and isinstance(property["anyOf"], list):
                 return self._convert_body_property_any_of(property, value, property["anyOf"])
-        except ValueError as e:
+        except ValueError:
             return value

     def _invoke(
@@ -305,6 +310,7 @@ class ApiTool(Tool):
        """
        invoke http request
        """
+        response: httpx.Response | str = ""
         # assemble request
         headers = self.assembling_request(tool_parameters)
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 0556384bd2..cfa8e6b8b2 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -77,6 +77,8 @@ class ToolEngine:
                 raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

         # invoke the tool
+        if tool.identity is None:
+            raise ValueError("tool identity is not set")
         try:
             # hit the callback handler
             agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
@@ -205,6 +207,8 @@ class ToolEngine:
        """
        Invoke the tool with the given arguments.
        """
+        if tool.identity is None:
+            raise ValueError("tool identity is not set")
         started_at = datetime.now(UTC)
         meta = ToolInvokeMeta(
             time_cost=0.0,
@@ -250,7 +254,7 @@ class ToolEngine:
                 text = json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)
                 result += f"tool response: {text}."
             else:
-                result += f"tool response: {response.message}."
+                result += f"tool response: {response.message!r}."
         return result
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index bba4be6772..99f84ca274 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -8,6 +8,8 @@ from mimetypes import guess_extension, guess_type
 from typing import Optional, Union
 from uuid import uuid4

+import httpx
+
 from configs import dify_config
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
@@ -96,9 +98,8 @@ class ToolFileManager:
             response = ssrf_proxy.get(file_url)
             response.raise_for_status()
             blob = response.content
-        except Exception as e:
-            logger.exception(f"Failed to download file from {file_url}")
-            raise
+        except httpx.TimeoutException:
+            raise ValueError(f"timeout when downloading file from {file_url}")

         mimetype = guess_type(file_url)[0] or "octet/stream"
         extension = guess_extension(mimetype) or ".bin"
@@ -217,6 +218,6 @@ class ToolFileManager:

 # init tool_file_parser
-from core.file.tool_file_parser import tool_file_manager
+from core.file.tool_file_parser import tool_file_manager  # noqa: E402

 tool_file_manager["manager"] = ToolFileManager
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
index f9a126bf9b..4787d7d79c 100644
--- a/api/core/tools/tool_label_manager.py
+++ b/api/core/tools/tool_label_manager.py
@@ -93,7 +93,7 @@ class ToolLabelManager:
             db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
         )

-        tool_labels = {label.tool_id: [] for label in labels}
+        tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
         for label in labels:
             tool_labels[label.tool_id].append(label.label_name)
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 7e67c06873..0425ba7918 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -4,16 +4,18 @@ import mimetypes
 from collections.abc import Generator
 from os import listdir, path
 from threading import Lock
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union, cast

 from yarl import URL

 import contexts
 from core.plugin.entities.plugin import GenericProviderID
 from core.plugin.manager.tool import PluginToolManager
+from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.plugin_tool.provider import PluginToolProviderController
 from core.tools.plugin_tool.tool import PluginTool
+from core.tools.workflow_as_tool.provider import WorkflowToolProviderController

 if TYPE_CHECKING:
     from core.workflow.nodes.tool.entities import ToolEntity
@@ -39,7 +41,7 @@ from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolProviderType,
 )
-from core.tools.errors import ToolProviderNotFoundError
+from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import (
     ProviderConfigEncrypter,
@@ -57,7 +59,7 @@ class ToolManager:
     _builtin_provider_lock = Lock()
     _hardcoded_providers = {}
     _builtin_providers_loaded = False
-    _builtin_tools_labels = {}
+    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

     @classmethod
     def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@@ -140,6 +142,8 @@ class ToolManager:
         """
         provider_controller = cls.get_builtin_provider(provider, tenant_id)
         tool = provider_controller.get_tool(tool_name)
+        if tool is None:
+            raise ToolNotFoundError(f"tool {tool_name} not found")

         return tool

@@ -266,6 +270,11 @@ class ToolManager:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

             controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
+            controller_tools: Optional[list[Tool]] = controller.get_tools(
+                user_id="", tenant_id=workflow_provider.tenant_id
+            )
+            if controller_tools is None or len(controller_tools) == 0:
+                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

             return cast(
                 WorkflowTool,
@@ -333,6 +342,8 @@ class ToolManager:
                     identity_id=f"AGENT.{app_id}",
                 )
                 runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
+        if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
+            raise ValueError("runtime not found or runtime parameters not found")

         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
         return tool_entity
@@ -583,9 +594,11 @@ class ToolManager:
         # append builtin providers
         for provider in builtin_providers:
             # handle include, exclude
+            if provider.identity is None:
+                continue
             if is_filtered(
-                include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
-                exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
+                include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
+                exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
                 data=provider,
                 name_func=lambda x: x.identity.name,
             ):
@@ -609,7 +622,7 @@ class ToolManager:
             db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
         )

-        api_provider_controllers = [
+        api_provider_controllers: list[dict[str, Any]] = [
             {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
             for provider in db_api_providers
         ]
@@ -632,7 +645,7 @@ class ToolManager:
             db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
         )

-        workflow_provider_controllers = []
+        workflow_provider_controllers: list[WorkflowToolProviderController] = []
         for provider in workflow_providers:
             try:
                 workflow_provider_controllers.append(
@@ -642,7 +655,9 @@ class ToolManager:
                 # app has been deleted
                 pass

-        labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
+        labels = ToolLabelManager.get_tools_labels(
+            [cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
+        )

         for provider_controller in workflow_provider_controllers:
             user_provider = ToolTransformService.workflow_provider_to_user_provider(
@@ -693,7 +708,7 @@ class ToolManager:
         get tool provider
         """
         provider_name = provider
-        provider_obj: ApiToolProvider = (
+        provider_obj: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -707,7 +722,7 @@ class ToolManager:

         try:
             credentials = json.loads(provider_obj.credentials_str) or {}
-        except:
+        except Exception:
             credentials = {}

         # package tool provider controller
@@ -728,7 +743,7 @@ class ToolManager:

         try:
             icon = json.loads(provider_obj.icon)
-        except:
+        except Exception:
             icon = {"background": "#252525", "content": "\ud83d\ude01"}

         # add tool labels
@@ -783,7 +798,7 @@ class ToolManager:
                 raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

             return json.loads(workflow_provider.icon)
-        except:
+        except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}

     @classmethod
@@ -799,7 +814,7 @@ class ToolManager:
                 raise ToolProviderNotFoundError(f"api provider {provider_id} not found")

             return json.loads(api_provider.icon)
-        except:
+        except Exception:
             return {"background": "#252525", "content": "\ud83d\ude01"}

     @classmethod
@@ -824,7 +839,7 @@ class ToolManager:
             if isinstance(provider, PluginToolProviderController):
                 try:
                     return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
-                except:
+                except Exception:
                     return {"background": "#252525", "content": "\ud83d\ude01"}
             return cls.generate_builtin_tool_icon_url(provider_id)
         elif provider_type == ToolProviderType.API:
@@ -836,7 +851,7 @@ class ToolManager:
             if isinstance(provider, PluginToolProviderController):
                 try:
                     return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
-                except:
+                except Exception:
                     return {"background": "#252525", "content": "\ud83d\ude01"}
             raise ValueError(f"plugin provider {provider_id} not found")
         else:
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index a1014d2266..581fb622ac 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -101,7 +101,7 @@ class ProviderConfigEncrypter(BaseModel):
                     continue
                 data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
-        except:
+        except Exception:
             pass

         cache.set(data)
@@ -221,6 +221,9 @@ class ToolParameterConfigurationManager:

         return a deep copy of parameters with decrypted values
         """
+        if self.tool_runtime is None or self.tool_runtime.identity is None:
+            raise ValueError("tool_runtime is required")
+
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
             provider=f"{self.provider_type.value}.{self.provider_name}",
@@ -245,7 +248,7 @@ class ToolParameterConfigurationManager:
                 try:
                     has_secret_input = True
                     parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
-                except:
+                except Exception:
                     pass

         if has_secret_input:
diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
index d3a8752d9a..08b6e89806 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
@@ -1,4 +1,5 @@
 import threading
+from typing import Any

 from flask import Flask, current_app
 from pydantic import BaseModel, Field
@@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document as RagDocument
 from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment

-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -44,7 +46,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):

     def _run(self, query: str) -> str:
         threads = []
-        all_documents = []
+        all_documents: list[RagDocument] = []
         for dataset_id in self.dataset_ids:
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={
@@ -77,8 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):

         document_score_list = {}
         for item in all_documents:
-            assert item.metadata
-            if item.metadata.get("score"):
+            if item.metadata and item.metadata.get("score"):
                 document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

         document_context_list = []
@@ -87,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             DocumentSegment.dataset_id.in_(self.dataset_ids),
             DocumentSegment.completed_at.isnot(None),
             DocumentSegment.status == "completed",
-            DocumentSegment.enabled == True,
+            DocumentSegment.enabled.is_(True),
             DocumentSegment.index_node_id.in_(index_node_ids),
         ).all()
@@ -108,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                 dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                 document = Document.query.filter(
                     Document.id == segment.document_id,
-                    Document.enabled == True,
-                    Document.archived == False,
+                    Document.enabled.is_(True),
+                    Document.archived.is_(False),
                 ).first()
                 if dataset and document:
                     source = {
@@ -140,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                 hit_callback.return_retriever_resource_info(context_list)

             return str("\n".join(document_context_list))
+        return ""

         raise RuntimeError("not segments found")
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
index dad8c77357..a4d2de3b1c 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import Any, Optional

-from msal_extensions.persistence import ABC
+from msal_extensions.persistence import ABC  # type: ignore
 from pydantic import BaseModel, ConfigDict

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index 987f94a350..b382016473 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field

 from core.rag.datasource.retrieval_service import RetrievalService
@@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     metadata=external_document.get("metadata"),
                     provider="external",
                 )
-                document.metadata["score"] = external_document.get("score")
-                document.metadata["title"] = external_document.get("title")
-                document.metadata["dataset_id"] = dataset.id
-                document.metadata["dataset_name"] = dataset.name
-                results.append(document)
+                if document.metadata is not None:
+                    document.metadata["score"] = external_document.get("score")
+                    document.metadata["title"] = external_document.get("title")
+                    document.metadata["dataset_id"] = dataset.id
+                    document.metadata["dataset_name"] = dataset.name
+                    results.append(document)
             # deal with external documents
             context_list = []
             for position, item in enumerate(results, start=1):
-                source = {
-                    "position": position,
-                    "dataset_id": item.metadata.get("dataset_id"),
-                    "dataset_name": item.metadata.get("dataset_name"),
-                    "document_name": item.metadata.get("title"),
-                    "data_source_type": "external",
-                    "retriever_from": self.retriever_from,
-                    "score": item.metadata.get("score"),
-                    "title": item.metadata.get("title"),
-                    "content": item.page_content,
-                }
+                if item.metadata is not None:
+                    source = {
+                        "position": position,
+                        "dataset_id": item.metadata.get("dataset_id"),
+                        "dataset_name": item.metadata.get("dataset_name"),
+                        "document_name": item.metadata.get("title"),
+                        "data_source_type": "external",
+                        "retriever_from": self.retriever_from,
+                        "score": item.metadata.get("score"),
+                        "title": item.metadata.get("title"),
+                        "content": item.page_content,
+                    }
                 context_list.append(source)
             for hit_callback in self.hit_callbacks:
                 hit_callback.return_retriever_resource_info(context_list)
@@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             return str("\n".join([item.page_content for item in results]))
         else:
             # get retrieval model , if the model is not setting , using default
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
+            retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
             if dataset.indexing_technique == "economy":
                 # use keyword table query
                 documents = RetrievalService.retrieve(
@@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                         score_threshold=retrieval_model.get("score_threshold", 0.0)
                         if retrieval_model["score_threshold_enabled"]
                         else 0.0,
-                        reranking_model=retrieval_model.get("reranking_model", None)
+                        reranking_model=retrieval_model.get("reranking_model")
                         if retrieval_model["reranking_enable"]
                         else None,
                         reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-                        weights=retrieval_model.get("weights", None),
+                        weights=retrieval_model.get("weights"),
                     )
             else:
                 documents = []
@@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             document_score_list = {}
             if dataset.indexing_technique != "economy":
                 for item in documents:
-                    if item.metadata.get("score"):
+                    if item.metadata is not None and item.metadata.get("score"):
                         document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
             document_context_list = []
             index_node_ids = [document.metadata["doc_id"] for document in documents]
@@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             context_list = []
             resource_number = 1
             for segment in sorted_segments:
-                context = {}
-                document = Document.query.filter(
+                document_segment = Document.query.filter(
                     Document.id == segment.document_id,
                     Document.enabled == True,
                     Document.archived == False,
                 ).first()
-                if dataset and document:
+                if not document_segment:
+                    continue
+                if dataset and document_segment:
                     source = {
                         "position": resource_number,
                         "dataset_id": dataset.id,
                         "dataset_name": dataset.name,
-                        "document_id": document.id,
-                        "document_name": document.name,
-                        "data_source_type": document.data_source_type,
+                        "document_id": document_segment.id,
+                        "document_name": document_segment.name,
+                        "data_source_type": document_segment.data_source_type,
                         "segment_id": segment.id,
                         "retriever_from": self.retriever_from,
                         "score": document_score_list.get(segment.index_node_id, None),
diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py
index 136491005c..ca03b1dc94 100644
--- a/api/core/tools/utils/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever_tool.py
@@ -94,6 +94,7 @@ class DatasetRetrieverTool(Tool):
                 llm_description="Query for the dataset to be used to retrieve the dataset.",
                 required=True,
                 default="",
+                placeholder=I18nObject(en_US="", zh_Hans=""),
             ),
         ]

@@ -112,7 +113,9 @@ class DatasetRetrieverTool(Tool):
             result = self.retrieval_tool._run(query=query)
             yield self.create_text_message(text=result)

-    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
+    def validate_credentials(
+        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+    ) -> str | None:
        """
        validate the credentials for dataset retriever tool
        """
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index b3c3292f5d..245470ea49 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
 """

 import json
-from typing import cast
+from typing import Optional, cast

 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
         if not schema:
             raise InvokeModelError("No model schema found")

-        max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
+        max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
         if max_tokens is None:
             return 2048

@@ -133,14 +133,17 @@ class ModelInvocationUtils:
         db.session.commit()

         try:
-            response: LLMResult = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters=model_parameters,
-                tools=[],
-                stop=[],
-                stream=False,
-                user=user_id,
-                callbacks=[],
+            response: LLMResult = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=[],
+                    stop=[],
+                    stream=False,
+                    user=user_id,
+                    callbacks=[],
+                ),
             )
         except InvokeRateLimitError as e:
             raise InvokeModelError(f"Invoke rate limit error: {e}")
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 3e53ca6223..43521e9c22 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -5,7 +5,7 @@ from json import loads as json_loads
 from json.decoder import JSONDecodeError

 from requests import get
-from yaml import YAMLError, safe_load
+from yaml import YAMLError, safe_load  # type: ignore

 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
@@ -63,6 +63,9 @@ class ApiBasedToolSchemaParser:
                     default=parameter["schema"]["default"]
                     if "schema" in parameter and "default" in parameter["schema"]
                     else None,
+                    placeholder=I18nObject(
+                        en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+                    ),
                 )

                 # check if there is a type
@@ -107,6 +110,9 @@ class ApiBasedToolSchemaParser:
                             form=ToolParameter.ToolParameterForm.LLM,
                             llm_description=property.get("description", ""),
                             default=property.get("default", None),
+                            placeholder=I18nObject(
+                                en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+                            ),
                         )

                         # check if there is a type
@@ -157,9 +163,9 @@ class ApiBasedToolSchemaParser:
         return bundles

     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
+    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
         parameter = parameter or {}
-        typ = None
+        typ: Optional[str] = None
         if parameter.get("format") == "binary":
             return ToolParameter.ToolParameterType.FILE

@@ -174,6 +180,8 @@ class ApiBasedToolSchemaParser:
             return ToolParameter.ToolParameterType.BOOLEAN
         elif typ == "string":
             return ToolParameter.ToolParameterType.STRING
+        else:
+            return None

     @staticmethod
     def parse_openapi_yaml_to_tool_bundle(
@@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
             if ("summary" not in operation or len(operation["summary"]) == 0) and (
                 "description" not in operation or len(operation["description"]) == 0
             ):
-                warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
+                if warning is not None:
+                    warning["missing_summary"] = f"No summary or description found in operation {method} {path}."

             openapi["paths"][path][method] = {
                 "operationId": operation["operationId"],
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 3aae31e93a..d42fd99fce 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -9,13 +9,13 @@ import tempfile
 import unicodedata
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Optional
+from typing import Any, Literal, Optional, cast
 from urllib.parse import unquote

 import chardet
-import cloudscraper
-from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from regex import regex
+import cloudscraper  # type: ignore
+from bs4 import BeautifulSoup, CData, Comment, NavigableString  # type: ignore
+from regex import regex  # type: ignore

 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
             return "Unsupported content-type [{}] of URL.".format(main_content_type)

         if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
-            return ExtractProcessor.load_from_url(url, return_text=True)
+            return cast(str, ExtractProcessor.load_from_url(url, return_text=True))

         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
     os.unlink(article_json_path)
     os.unlink(html_path)

-    article_json = {
+    article_json: dict[str, Any] = {
         "title": None,
         "byline": None,
         "date": None,
@@ -300,7 +300,7 @@ def strip_control_characters(text):
 def normalize_unicode(text):
     """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
-    normal_form = "NFKC"
+    normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
     text = unicodedata.normalize(normal_form, text)
     return text
@@ -332,6 +332,7 @@ def add_content_digest(element):


 def content_digest(element):
+    digest: Any
     if is_text(element):
         # Hash
         trimmed_string = element.string.strip()
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index 5cdf85e583..7e35dc7514 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration

 class WorkflowToolConfigurationUtils:
     @classmethod
-    def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
+    def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
         for configuration in configurations:
             WorkflowToolParameterConfiguration.model_validate(configuration)

@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
     @classmethod
     def check_is_synced(
         cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
-    ) -> None:
+    ) -> bool:
        """
        check is synced
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py
index 42c7f85bc6..ee7ca11e05 100644
--- a/api/core/tools/utils/yaml_utils.py
+++ b/api/core/tools/utils/yaml_utils.py
@@ -2,7 +2,7 @@ import logging
 from pathlib import Path
 from typing import Any

-import yaml
+import yaml  # type: ignore
 from yaml import YAMLError

 logger = logging.getLogger(__name__)
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index c40ea0a0b0..e6b61a88af 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
     ToolProviderIdentity,
     ToolProviderType,
 )
+from core.tools.__base.tool import Tool
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
@@ -130,6 +131,7 @@ class WorkflowToolProviderController(ToolProviderController):
                         llm_description=parameter.description,
                         required=variable.required,
                         options=options,
+                        placeholder=I18nObject(en_US="", zh_Hans=""),
                     )
                 )
             elif features.file_upload:
@@ -142,6 +144,7 @@ class WorkflowToolProviderController(ToolProviderController):
                         llm_description=parameter.description,
                         required=False,
                         form=parameter.form,
+                        placeholder=I18nObject(en_US="", zh_Hans=""),
                     )
                 )
             else:
@@ -198,6 +201,8 @@ class WorkflowToolProviderController(ToolProviderController):
         if not db_providers:
             return []

+        if not db_providers.app:
+            raise ValueError("app not found")
         app = db_providers.app

         if not app:
@@ -207,7 +212,7 @@ class WorkflowToolProviderController(ToolProviderController):

         return self.tools

-    def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
+    def get_tool(self, tool_name: str) -> Optional[Tool]:
        """
        get tool by name
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 2998fb8ce2..b9238d11a0 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -102,7 +102,7 @@ class WorkflowTool(Tool):
             raise Exception(data.get("error"))

         outputs = data.get("outputs")
-        if outputs == None:
+        if outputs is None:
             outputs = {}
         else:
             outputs, files = self._extract_files(outputs)
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 2b1a58f93a..7a1cbf9940 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -21,6 +21,7 @@ from .variables import (
     ArrayNumberVariable,
     ArrayObjectVariable,
     ArrayStringVariable,
+    ArrayVariable,
     FileVariable,
     FloatVariable,
     IntegerVariable,
@@ -43,6 +44,7 @@ __all__ = [
     "ArraySegment",
     "ArrayStringSegment",
     "ArrayStringVariable",
+    "ArrayVariable",
     "FileSegment",
     "FileVariable",
     "FloatSegment",
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index c902303eef..c32815b24d 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from typing import cast
 from uuid import uuid4

 from pydantic import Field
@@ -10,6 +11,7 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArraySegment,
     ArrayStringSegment,
     FileSegment,
     FloatSegment,
@@ -52,19 +54,23 @@ class ObjectVariable(ObjectSegment, Variable):
     pass


-class ArrayAnyVariable(ArrayAnySegment, Variable):
+class ArrayVariable(ArraySegment, Variable):
     pass


-class ArrayStringVariable(ArrayStringSegment, Variable):
+class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
     pass


-class ArrayNumberVariable(ArrayNumberSegment, Variable):
+class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
     pass


-class ArrayObjectVariable(ArrayObjectSegment, Variable):
+class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
+    pass
+
+
+class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
     pass


@@ -73,7 +79,7 @@ class SecretVariable(StringVariable):

     @property
     def log(self) -> str:
-        return encrypter.obfuscated_token(self.value)
+        return cast(str, encrypter.obfuscated_token(self.value))


 class NoneVariable(NoneSegment, Variable):
@@ -85,5 +91,5 @@ class FileVariable(FileSegment, Variable):
     pass


-class ArrayFileVariable(ArrayFileSegment, Variable):
+class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
     pass
diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py
index ed737e7316..b9c6b35ad3 100644
--- a/api/core/workflow/callbacks/workflow_logging_callback.py
+++ b/api/core/workflow/callbacks/workflow_logging_callback.py
@@ -33,7 +33,7 @@ _TEXT_COLOR_MAPPING = {

 class WorkflowLoggingCallback(WorkflowCallback):
     def __init__(self) -> None:
-        self.current_node_id = None
+        self.current_node_id: Optional[str] = None

     def on_event(self, event: GraphEngineEvent) -> None:
         if isinstance(event, GraphRunStartedEvent):
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index e5449aacfb..ed26889614 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -37,12 +37,15 @@ class NodeRunResult(BaseModel):

     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

     inputs: Optional[Mapping[str, Any]] = None  # node inputs
-    process_data: Optional[dict[str, Any]] = None  # process data
+    process_data: Optional[Mapping[str, Any]] = None  # process data
     outputs: Optional[Mapping[str, Any]] = None  # node outputs
-    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
+    metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None  # node metadata
     llm_usage: Optional[LLMUsage] = None  # llm usage

     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
     error: Optional[str] = None  # error message if status is failed
     error_type: Optional[str] = None  # error type if status is failed
+
+    # single step node run retry
+    retry_index: int = 0
diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
index bc3a15bd00..b8470aecbd 100644
--- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
+++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py
@@ -5,7 +5,7 @@ from core.workflow.utils.condition.processor import ConditionProcessor


 class ConditionRunConditionHandlerHandler(RunConditionHandler):
-    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
+    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
        """
        Check if the condition can be executed
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 08cd2fe463..dca1021163 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):

 class GraphRunFailedEvent(BaseGraphEvent):
     error: str = Field(..., description="failed reason")
-    exceptions_count: Optional[int] = Field(description="exception count", default=0)
+    exceptions_count: int = Field(description="exception count", default=0)


 class GraphRunPartialSucceededEvent(BaseGraphEvent):
@@ -97,6 +97,12 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")


+class NodeRunRetryEvent(NodeRunStartedEvent):
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="which retry attempt is about to be performed")
+    start_at: datetime = Field(..., description="retry start time")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 4f7bc60e26..b3bcc3b2cc 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -1,9 +1,11 @@
 import uuid
+from collections import defaultdict
 from collections.abc import Mapping
 from typing import Any, Optional, cast

 from pydantic import BaseModel, Field

+from configs import dify_config
 from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@@ -170,7 +172,9 @@ class Graph(BaseModel):
         for parallel in parallel_mapping.values():
             if parallel.parent_parallel_id:
                 cls._check_exceed_parallel_limit(
-                    parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
+                    parallel_mapping=parallel_mapping,
+                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
+                    parent_parallel_id=parallel.parent_parallel_id,
                 )

         # init answer stream generate routes
@@ -307,26 +311,17 @@ class Graph(BaseModel):
         parallel = None
         if len(target_node_edges) > 1:
             # fetch all node ids in current parallels
-            parallel_branch_node_ids = {}
-            condition_edge_mappings = {}
+            parallel_branch_node_ids = defaultdict(list)
+            condition_edge_mappings = defaultdict(list)
             for graph_edge in target_node_edges:
                 if graph_edge.run_condition is None:
-                    if "default" not in parallel_branch_node_ids:
-                        parallel_branch_node_ids["default"] = []
-
                     parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
                 else:
                     condition_hash = graph_edge.run_condition.hash
-                    if condition_hash not in condition_edge_mappings:
-                        condition_edge_mappings[condition_hash] = []
-
                     condition_edge_mappings[condition_hash].append(graph_edge)

             for condition_hash, graph_edges in condition_edge_mappings.items():
                 if len(graph_edges) > 1:
-                    if condition_hash not in parallel_branch_node_ids:
-                        parallel_branch_node_ids[condition_hash] = []
-
                     for graph_edge in graph_edges:
                         parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
@@ -415,7 +410,7 @@ class Graph(BaseModel):
             if condition_edge_mappings:
                 for condition_hash, graph_edges in condition_edge_mappings.items():
                     for graph_edge in graph_edges:
-                        current_parallel: GraphParallel | None = cls._get_current_parallel(
+                        current_parallel = cls._get_current_parallel(
                             parallel_mapping=parallel_mapping,
                             graph_edge=graph_edge,
                             parallel=condition_parallels.get(condition_hash),
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 3d1a42374e..463525b9f4 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -6,6 +6,7 @@ import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast

 from flask import Flask, current_app

@@ -26,6 +27,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -39,6 +41,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
@@ -65,7 +68,7 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
         self.max_submit_count = max_submit_count
         self.submit_count = 0

-    def submit(self, fn, *args, **kwargs):
+    def submit(self, fn, /, *args, **kwargs):
         self.submit_count += 1
         self.check_is_full()

@@ -139,7 +142,8 @@ class GraphEngine:
     def run(self) -> Generator[GraphEngineEvent, None, None]:
         # trigger graph run start event
         yield GraphRunStartedEvent()
-        handle_exceptions = []
+        handle_exceptions: list[str] = []
+        stream_processor: StreamProcessor

         try:
             if self.init_params.workflow_type == WorkflowType.CHAT:
@@ -349,7 +353,7 @@ class GraphEngine:

         if any(edge.run_condition for edge in edge_mappings):
             # if nodes has run conditions, get node id which branch to take based on the run condition results
-            condition_edge_mappings = {}
+            condition_edge_mappings: dict[str, list[GraphEdge]] = {}
             for edge in edge_mappings:
                 if edge.run_condition:
                     run_condition_hash = edge.run_condition.hash
@@ -363,6 +367,9 @@ class GraphEngine:
                     continue

                 edge = cast(GraphEdge, sub_edge_mappings[0])
+                if edge.run_condition is None:
+                    logger.warning(f"Edge {edge.target_node_id} run condition is None")
+                    continue

                 result = ConditionManager.get_condition_handler(
                     init_params=self.init_params,
@@ -386,11 +393,11 @@ class GraphEngine:
                         handle_exceptions=handle_exceptions,
                     )

-                    for item in parallel_generator:
-                        if isinstance(item, str):
-                            final_node_id = item
+                    for parallel_result in parallel_generator:
+                        if isinstance(parallel_result, str):
+                            final_node_id = parallel_result
                         else:
-                            yield item
+                            yield parallel_result

                     break
@@ -412,11 +419,11 @@ class GraphEngine:
                         handle_exceptions=handle_exceptions,
                     )

-                    for item in parallel_generator:
-                        if isinstance(item, str):
-                            final_node_id = item
+                    for generated_item in parallel_generator:
+                        if isinstance(generated_item, str):
+                            final_node_id = generated_item
                         else:
-                            yield item
+                            yield generated_item

                 if not final_node_id:
                     break
@@ -587,7 +594,7 @@ class GraphEngine:

     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -613,36 +620,120 @@ class GraphEngine:
         )
         db.session.close()

+        max_retries = node_instance.node_data.retry_config.max_retries
+        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+        retries = 0
+        should_continue_retry = True
+        while should_continue_retry and retries <= max_retries:
+            try:
+                # run node
+                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
+                generator = node_instance.run()
+                for item in generator:
+                    if isinstance(item, GraphEngineEvent):
+                        if isinstance(item, BaseIterationEvent):
+                            # add parallel info to iteration event
+                            item.parallel_id = parallel_id
+                            item.parallel_start_node_id = parallel_start_node_id
+                            item.parent_parallel_id = parent_parallel_id
+                            item.parent_parallel_start_node_id = parent_parallel_start_node_id

-        try:
-            # run node
-            generator = node_instance.run()
-            for item in generator:
-                if isinstance(item, GraphEngineEvent):
-                    if isinstance(item, BaseIterationEvent):
-                        # add parallel info to iteration event
-                        item.parallel_id = parallel_id
-                        item.parallel_start_node_id = parallel_start_node_id
-                        item.parent_parallel_id = parent_parallel_id
-                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
+                        yield item
+                    else:
+                        if isinstance(item, RunCompletedEvent):
+                            run_result = item.run_result
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if (
+                                    retries == max_retries
+                                    and node_instance.node_type == NodeType.HTTP_REQUEST
+                                    and run_result.outputs
+                                    and not node_instance.should_continue_on_error
+                                ):
+                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
+                                if node_instance.should_retry and retries < max_retries:
+                                    retries += 1
+                                    route_node_state.node_run_result = run_result
+                                    yield NodeRunRetryEvent(
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        predecessor_node_id=node_instance.previous_node_id,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                        error=run_result.error or "Unknown error",
+                                        retry_index=retries,
+                                        start_at=retry_start_at,
+                                    )
+                                    time.sleep(retry_interval)
+                                    continue
+                            route_node_state.set_finished(run_result=run_result)

-                    yield item
-                else:
-                    if isinstance(item, RunCompletedEvent):
-                        run_result = item.run_result
-                        route_node_state.set_finished(run_result=run_result)
+                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
+                                if node_instance.should_continue_on_error:
+                                    # if run failed, handle error
+                                    run_result = self._handle_continue_on_error(
+                                        node_instance,
+                                        item.run_result,
+                                        self.graph_runtime_state.variable_pool,
+                                        handle_exceptions=handle_exceptions,
+                                    )
+                                    route_node_state.node_run_result = run_result
+                                    route_node_state.status = RouteNodeState.Status.EXCEPTION
+                                    if run_result.outputs:
+                                        for variable_key, variable_value in run_result.outputs.items():
+                                            # append variables to variable pool recursively
+                                            self._append_variables_recursively(
+                                                node_id=node_instance.node_id,
+                                                variable_key_list=[variable_key],
+                                                variable_value=variable_value,
+                                            )
+                                    yield NodeRunExceptionEvent(
+                                        error=run_result.error or "System Error",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
+                                        parallel_id=parallel_id,
+                                        parallel_start_node_id=parallel_start_node_id,
+                                        parent_parallel_id=parent_parallel_id,
+                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                                    )
+                                    should_continue_retry = False
+                                else:
+                                    yield NodeRunFailedEvent(
+                                        error=route_node_state.failed_reason or "Unknown error.",
+                                        id=node_instance.id,
+                                        node_id=node_instance.node_id,
+                                        node_type=node_instance.node_type,
+                                        node_data=node_instance.node_data,
+                                        route_node_state=route_node_state,
parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + should_continue_retry = False + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if node_instance.should_continue_on_error and self.graph.edge_mapping.get( + node_instance.node_id + ): + run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) - if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.should_continue_on_error: - # if run failed, handle error - run_result = self._handle_continue_on_error( - node_instance, - item.run_result, - self.graph_runtime_state.variable_pool, - handle_exceptions=handle_exceptions, - ) - route_node_state.node_run_result = run_result - route_node_state.status = RouteNodeState.Status.EXCEPTION + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively @@ -651,133 +742,86 @@ class GraphEngine: variable_key_list=[variable_key], variable_value=variable_value, ) - yield NodeRunExceptionEvent( - error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - else: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - if node_instance.should_continue_on_error and self.graph.edge_mapping.get( - node_instance.node_id - ): - run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS - if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - # plus state total_tokens - self.graph_runtime_state.total_tokens += int( - run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] - ) - - if run_result.llm_usage: - # use the latest usage - self.graph_runtime_state.llm_usage += run_result.llm_usage - - # append node output variables to variable pool - if run_result.outputs: - for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node_instance.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, - ) - - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: + # When setting metadata, convert to dict first if not run_result.metadata: run_result.metadata = {} - 
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id - if parent_parallel_id and parent_parallel_start_node_id: - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( - parent_parallel_start_node_id - ) + if parallel_id and parallel_start_node_id: + metadata_dict = dict(run_result.metadata) + metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) + run_result.metadata = metadata_dict - yield NodeRunSucceededEvent( + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + should_continue_retry = False + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - - break - elif isinstance(item, RunStreamChunkEvent): - yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - elif isinstance(item, RunRetrieverResourceEvent): - yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - retriever_resources=item.retriever_resources, - context=item.context, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - except GenerateTaskStoppedError: - # trigger node run failed event - route_node_state.status = RouteNodeState.Status.FAILED - route_node_state.failed_reason = "Workflow stopped." 
- yield NodeRunFailedEvent( - error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - return - except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") - raise e - finally: - db.session.close() + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + except GenerateTaskStoppedError: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." + yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed") + raise e + finally: + db.session.close() def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ @@ -836,8 +880,8 @@ class GraphEngine: variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions - handle_exceptions.append(error_result.error) - node_error_args = { + handle_exceptions.append(error_result.error or "") + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 1b948bf592..7d652d39f7 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -147,6 +147,8 @@ class AnswerStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") source_node_data = node_id_config_mapping[source_node_id].get("data", {}) if ( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d94f059058..40213bd151 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -60,11 +60,10 @@ class AnswerStreamProcessor(StreamProcessor): del 
self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
 
-            # remove unreachable nodes
             self._remove_unreachable_nodes(event)
 
             # generate stream outputs
-            yield from self._generate_stream_outputs_when_node_finished(event)
+            yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
         else:
             yield event
 
@@ -131,7 +130,7 @@ class AnswerStreamProcessor(StreamProcessor):
                 node_type=event.node_type,
                 node_data=event.node_data,
                 chunk_content=text,
-                from_variable_selector=value_selector,
+                from_variable_selector=list(value_selector),
                 route_node_state=event.route_node_state,
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py
index 36c3fe180a..8ffb487ec1 100644
--- a/api/core/workflow/nodes/answer/base_stream_processor.py
+++ b/api/core/workflow/nodes/answer/base_stream_processor.py
@@ -1,10 +1,13 @@
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
 
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
 from core.workflow.graph_engine.entities.graph import Graph
 
+logger = logging.getLogger(__name__)
+
 
 class StreamProcessor(ABC):
     def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
@@ -16,7 +19,7 @@ class StreamProcessor(ABC):
     def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
         raise NotImplementedError
 
-    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
         finished_node_id = event.route_node_state.node_id
         if finished_node_id not in self.rest_node_ids:
             return
@@ -29,15 +32,24 @@ class StreamProcessor(ABC):
             return
 
         if run_result.edge_source_handle:
-            reachable_node_ids = []
-            unreachable_first_node_ids = []
+            reachable_node_ids: list[str] = []
+            unreachable_first_node_ids: list[str] = []
+            if finished_node_id not in self.graph.edge_mapping:
+                logger.warning(f"node {finished_node_id} has no edge mapping")
+                return
             for edge in self.graph.edge_mapping[finished_node_id]:
                 if (
                     edge.run_condition
                     and edge.run_condition.branch_identify
                     and run_result.edge_source_handle == edge.run_condition.branch_identify
                 ):
-                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+                    # remove unreachable nodes
+                    # FIXME: branches can merge directly, so removing nodes here may short-circuit
+                    # the answer node. This code is commented out for now; that has no effect on the
+                    # answer node or the workflow. Re-enable it once there is a better solution.
+                    # Issues: #11542 #9560 #10638 #10564
+
+                    # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                     continue
                 else:
                     unreachable_first_node_ids.append(edge.target_node_id)
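The pruning rule in `_remove_unreachable_nodes` above is easy to lose in the diff: when a branch node finishes, every outgoing edge whose `run_condition.branch_identify` differs from the taken `edge_source_handle` roots a branch that can no longer run. A minimal sketch of that selection, using simplified stand-in types rather than Dify's real `GraphEdge` and `edge_mapping` entities:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Edge:
    target_node_id: str
    branch_identify: Optional[str] = None  # None models an unconditional edge


def first_unreachable_nodes(
    edge_mapping: dict[str, list[Edge]], finished_node_id: str, taken_handle: str
) -> list[str]:
    """Return the first node of every branch the finished node did NOT take."""
    unreachable: list[str] = []
    # Mirrors the guard added above: a node without outgoing edges has nothing to prune.
    for edge in edge_mapping.get(finished_node_id, []):
        if edge.branch_identify == taken_handle:
            continue  # the taken branch stays reachable
        unreachable.append(edge.target_node_id)
    return unreachable


# An if/else node "if_1" whose "true" branch was taken:
edges = {"if_1": [Edge("llm_1", "true"), Edge("answer_2", "false")]}
assert first_unreachable_nodes(edges, "if_1", "true") == ["answer_2"]
```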
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 9271867aff..6bf8899f5d 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -38,7 +38,8 @@ class DefaultValue(BaseModel):
     @staticmethod
     def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
         """Unified array type validation"""
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
+        # FIXME: type ignored because the cause of the mypy complaint is unknown; remove once the root cause is found
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
 
     @staticmethod
     def _convert_number(value: str) -> float:
@@ -84,7 +85,7 @@ class DefaultValue(BaseModel):
             },
         }
 
-        validator = type_validators.get(self.type)
+        validator: dict[str, Any] = type_validators.get(self.type, {})
         if not validator:
             if self.type == DefaultValueType.ARRAY_FILES:
                 # Handle files type
@@ -106,12 +107,25 @@ class DefaultValue(BaseModel):
         return self
 
 
+class RetryConfig(BaseModel):
+    """Node retry config"""
+
+    max_retries: int = 0  # max retry times
+    retry_interval: int = 0  # retry interval in milliseconds
+    retry_enabled: bool = False  # whether retry is enabled
+
+    @property
+    def retry_interval_seconds(self) -> float:
+        return self.retry_interval / 1000
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
     version: str = "1"
+    retry_config: RetryConfig = RetryConfig()
 
     @property
     def default_value_dict(self):
diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py
index ec134e031c..aeecf40640 100644
--- a/api/core/workflow/nodes/base/exc.py
+++ b/api/core/workflow/nodes/base/exc.py
@@ -1,4 +1,4 @@
-class BaseNodeError(Exception):
+class BaseNodeError(ValueError):
     """Base class for node errors."""
 
     pass
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index e1e28af60b..b799e74266 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
 
 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
             result = self._run()
         except Exception as e:
             logger.exception(f"Node {self.node_id} failed to run")
-            result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
+            result = NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e),
+                error_type="WorkflowNodeError",
+            )
 
         if isinstance(result, NodeRunResult):
             yield RunCompletedEvent(run_result=result)
@@ -143,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
             bool: if should continue on error
         """
         return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+
+    @property
+    def should_retry(self) -> bool:
+        """Judge whether this node should retry on failure.
+
+        Returns:
+            bool: True if retry is enabled and the node type supports retrying
+        """
+        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
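The new `RetryConfig` stores its interval in milliseconds while the engine sleeps in seconds, and `should_retry` gates on both the enabled flag and the node type. A self-contained sketch of that contract; the node-type strings below are illustrative stand-ins for the real `NodeType` members:

```python
from pydantic import BaseModel


class RetryConfig(BaseModel):
    """Minimal mirror of the node retry config added in this diff."""

    max_retries: int = 0  # maximum number of retry attempts
    retry_interval: int = 0  # wait between attempts, in milliseconds
    retry_enabled: bool = False

    @property
    def retry_interval_seconds(self) -> float:
        # The graph engine passes this value straight to time.sleep().
        return self.retry_interval / 1000


# Stand-in for RETRY_ON_ERROR_NODE_TYPE, which aliases CONTINUE_ON_ERROR_NODE_TYPE.
RETRYABLE_NODE_TYPES = {"llm", "code", "tool", "http-request"}


def should_retry(config: RetryConfig, node_type: str) -> bool:
    return config.retry_enabled and node_type in RETRYABLE_NODE_TYPES


config = RetryConfig(max_retries=3, retry_interval=500, retry_enabled=True)
assert should_retry(config, "http-request")
assert config.retry_interval_seconds == 0.5
```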
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 19b9078a5c..2f82bf8c38 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,5 +1,5 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 from configs import dify_config
 from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
             )
 
             # Transform result
-            result = self._transform_result(result, self.node_data.outputs)
+            result = self._transform_result(result=result, output_schema=self.node_data.outputs)
         except (CodeExecutionError, CodeNodeError) as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
             )
@@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
 
         return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
 
-    def _check_string(self, value: str, variable: str) -> str:
+    def _check_string(self, value: str | None, variable: str) -> str | None:
         """
         Check string
         :param value: value
         :param variable: variable
         :return:
         """
+        if value is None:
+            return None
         if not isinstance(value, str):
-            if value is None:
-                return None
-            else:
-                raise OutputValidationError(f"Output variable `{variable}` must be a string")
+            raise OutputValidationError(f"Output variable `{variable}` must be a string")
 
         if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
             raise OutputValidationError(
@@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
 
         return value.replace("\x00", "")
 
-    def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
+    def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
         """
         Check number
         :param value: value
         :param variable: variable
         :return:
         """
+        if value is None:
+            return None
         if not isinstance(value, int | float):
-            if value is None:
-                return None
-            else:
-                raise OutputValidationError(f"Output variable `{variable}` must be a number")
+            raise OutputValidationError(f"Output variable `{variable}` must be a number")
 
         if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
             raise OutputValidationError(
@@ -118,18 +116,16 @@ class CodeNode(BaseNode[CodeNodeData]):
         return value
 
     def _transform_result(
-        self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
-    ) -> dict:
-        """
-        Transform result
-        :param result: result
-        :param output_schema: output schema
-        :return:
-        """
+        self,
+        result: Mapping[str, Any],
+        output_schema: Optional[dict[str, CodeNodeData.Output]],
+        prefix: str = "",
+        depth: int = 1,
+    ):
         if depth > dify_config.CODE_MAX_DEPTH:
             raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
 
-        transformed_result = {}
+        transformed_result: dict[str, Any] = {}
         if output_schema is None:
             # validate output thought instance type
             for output_name, output_value in result.items():
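The `_check_string` and `_check_number` rewrites above share one guard-clause pattern: pass `None` through, fail fast on a wrong type, then validate limits on the straight-line path. A standalone sketch of the string variant, with a plain `ValueError` and a hard-coded limit standing in for `OutputValidationError` and `dify_config.CODE_MAX_STRING_LENGTH`:

```python
MAX_STRING_LENGTH = 80_000  # stand-in for dify_config.CODE_MAX_STRING_LENGTH


def check_string(value: str | None, variable: str) -> str | None:
    # Guard clauses keep the happy path unindented.
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError(f"Output variable `{variable}` must be a string")
    if len(value) > MAX_STRING_LENGTH:
        raise ValueError(
            f"The length of output variable `{variable}` must be"
            f" less than {MAX_STRING_LENGTH} characters"
        )
    # Strip NUL bytes, matching what the node returns.
    return value.replace("\x00", "")


assert check_string(None, "out") is None
assert check_string("a\x00b", "out") == "ab"
```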
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index e78183baf1..a454035888 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
 
     class Output(BaseModel):
         type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
-        children: Optional[dict[str, "Output"]] = None
+        children: Optional[dict[str, "CodeNodeData.Output"]] = None
 
     class Dependency(BaseModel):
         name: str
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 59afe7ac87..0b1dc611c5 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -1,19 +1,15 @@
 import csv
 import io
 import json
+import logging
 import os
 import tempfile
+from typing import cast
 
 import docx
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx
 
 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -28,6 +24,8 @@ from models.workflow import WorkflowNodeExecutionStatus
 from .entities import DocumentExtractorNodeData
 from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
 
+logger = logging.getLogger(__name__)
+
 
 class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
     """
@@ -162,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
     """Extract the content from yaml file"""
     try:
         yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
-        return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
+        return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
     except (UnicodeDecodeError, yaml.YAMLError) as e:
         raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
 
@@ -183,10 +181,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
 
 
 def _extract_text_from_doc(file_content: bytes) -> str:
+    """
+    Extract text from a DOC/DOCX file.
+    For now, only paragraphs and tables are supported; add more if needed.
+    """
    try:
         doc_file = io.BytesIO(file_content)
         doc = docx.Document(doc_file)
-        return "\n".join([paragraph.text for paragraph in doc.paragraphs])
+        text = []
+        # Process paragraphs
+        for paragraph in doc.paragraphs:
+            if paragraph.text.strip():
+                text.append(paragraph.text)
+
+        # Process tables
+        for table in doc.tables:
+            # Table header
+            try:
+                # Tables may raise errors, so skip any table that fails.
+ if len(table.rows) > 0 and table.rows[0].cells is not None: + # Check if any cell in the table has text + has_content = False + for row in table.rows: + if any(cell.text.strip() for cell in row.cells): + has_content = True + break + + if has_content: + markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" + for row in table.rows[1:]: + markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" + text.append(markdown_table) + except Exception as e: + logger.warning(f"Failed to extract table from DOC/DOCX: {e}") + continue + + return "\n".join(text) except Exception as e: raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e @@ -199,9 +230,9 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError("Missing URL for remote file") response = ssrf_proxy.get(file.remote_url) response.raise_for_status() - return response.content + return cast(bytes, response.content) else: - return file_manager.download(file) + return cast(bytes, file_manager.download(file)) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e @@ -256,6 +287,8 @@ def _extract_text_from_excel(file_content: bytes) -> str: def _extract_text_from_ppt(file_content: bytes) -> str: + from unstructured.partition.ppt import partition_ppt + try: with io.BytesIO(file_content) as file: elements = partition_ppt(file=file) @@ -265,6 +298,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes) -> str: + from unstructured.partition.api import partition_via_api + from unstructured.partition.pptx import partition_pptx + try: if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: @@ -287,6 +323,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str: def _extract_text_from_epub(file_content: bytes) -> str: + from unstructured.partition.epub import partition_epub + try: with io.BytesIO(file_content) as file: elements = partition_epub(file=file) @@ -296,6 +334,8 @@ def _extract_text_from_epub(file_content: bytes) -> str: def _extract_text_from_eml(file_content: bytes) -> str: + from unstructured.partition.email import partition_email + try: with io.BytesIO(file_content) as file: elements = partition_email(file=file) @@ -305,6 +345,8 @@ def _extract_text_from_eml(file_content: bytes) -> str: def _extract_text_from_msg(file_content: bytes) -> str: + from unstructured.partition.msg import partition_msg + try: with io.BytesIO(file_content) as file: elements = partition_msg(file=file) diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index ea8b6b5042..b3678a82b7 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -67,7 +67,7 @@ class EndStreamGeneratorRouter: and node_type == NodeType.LLM.value and variable_selector.value_selector[1] == "text" ): - value_selectors.append(variable_selector.value_selector) + value_selectors.append(list(variable_selector.value_selector)) return value_selectors @@ -119,8 +119,7 @@ class EndStreamGeneratorRouter: current_node_id: str, end_node_id: str, node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] + 
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] end_dependencies: dict[str, list[str]], ) -> None: """ @@ -135,6 +134,8 @@ class EndStreamGeneratorRouter: reverse_edges = reverse_edge_mapping.get(current_node_id, []) for edge in reverse_edges: source_node_id = edge.source_node_id + if source_node_id not in node_id_config_mapping: + continue source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") if source_node_type in { NodeType.IF_ELSE.value, diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 1aecf863ac..a770eb951f 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -23,7 +23,7 @@ class EndStreamProcessor(StreamProcessor): self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.has_output = False - self.output_node_ids = set() + self.output_node_ids: set[str] = set() def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index a077c33278..25e049577a 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -36,3 +36,4 @@ class FailBranchSourceHandle(StrEnum): CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] +RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py index 5e3b31e48b..08c47d5e57 100644 --- a/api/core/workflow/nodes/event/__init__.py +++ b/api/core/workflow/nodes/event/__init__.py @@ -1,4 +1,10 @@ -from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .event import ( + ModelInvokeCompletedEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunRetryEvent, + RunStreamChunkEvent, +) from .types import NodeEvent __all__ = [ @@ -6,5 +12,6 @@ __all__ = [ "NodeEvent", "RunCompletedEvent", "RunRetrieverResourceEvent", + "RunRetryEvent", "RunStreamChunkEvent", ] diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index b7034561bf..9fea3fbda3 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,7 +1,10 @@ +from datetime import datetime + from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus class RunCompletedEvent(BaseModel): @@ -26,3 +29,19 @@ class ModelInvokeCompletedEvent(BaseModel): text: str usage: LLMUsage finish_reason: str | None = None + + +class RunRetryEvent(BaseModel): + """Node Run Retry event""" + + error: str = Field(..., description="error") + retry_index: int = Field(..., description="Retry attempt number") + start_at: datetime = Field(..., description="Retry start time") + + +class SingleStepRetryEvent(NodeRunResult): + """Single step retry event""" + + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY + + elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index 7a5ab7dbc1..a815f277be 100644 --- 
a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -16,3 +16,7 @@ class InvalidHttpMethodError(HttpRequestNodeError): class ResponseSizeError(HttpRequestNodeError): """Raised when the response size exceeds the allowed threshold.""" + + +class RequestBodyError(HttpRequestNodeError): + """Raised when the request body is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 90251c27a8..cdfdc6e6d5 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -23,6 +23,7 @@ from .exc import ( FileFetchError, HttpRequestNodeError, InvalidHttpMethodError, + RequestBodyError, ResponseSizeError, ) @@ -45,6 +46,7 @@ class Executor: headers: dict[str, str] auth: HttpRequestNodeAuthorization timeout: HttpRequestNodeTimeout + max_retries: int boundary: str @@ -54,6 +56,7 @@ class Executor: node_data: HttpRequestNodeData, timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, + max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -73,6 +76,7 @@ class Executor: self.files = None self.data = None self.json = None + self.max_retries = max_retries # init template self.variable_pool = variable_pool @@ -103,9 +107,9 @@ class Executor: if not (key := key.strip()): continue - value = value[0].strip() if value else "" + value_str = value[0].strip() if value else "" result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text) + (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) self.params = result @@ -140,13 +144,19 @@ class Executor: case "none": self.content = "" case "raw-text": + if len(data) != 1: + raise RequestBodyError("raw-text body type should have exactly one item") self.content = self.variable_pool.convert_template(data[0].value).text case "json": + if len(data) != 1: + raise RequestBodyError("json body type should have exactly one item") json_string = self.variable_pool.convert_template(data[0].value).text json_object = json.loads(json_string, strict=False) self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": + if len(data) != 1: + raise RequestBodyError("binary body type should have exactly one item") file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: @@ -172,9 +182,10 @@ class Executor: self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } + files: dict[str, Any] = {} files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items()} + files = {k: variable.value for k, variable in files.items() if variable is not None} files = { k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") for k, v in files.items() @@ -241,13 +252,15 @@ class Executor: "params": self.params, "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "follow_redirects": True, + "max_retries": self.max_retries, } # request_args = {k: v for k, v in request_args.items() if v is not None} try: response = getattr(ssrf_proxy, 
self.method)(**request_args)
-        except ssrf_proxy.MaxRetriesExceededError as e:
+        except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
             raise HttpRequestNodeError(str(e))
-        return response
+        # FIXME: remove this type ignore; it may be an httpx typing issue
+        return response  # type: ignore
 
     def invoke(self) -> Response:
         # assemble headers
@@ -289,35 +302,37 @@
                 continue
             raw += f"{k}: {v}\r\n"
 
-        body = ""
+        body_string = ""
         if self.files:
             for k, v in self.files.items():
-                body += f"--{boundary}\r\n"
-                body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
-                body += f"{v[1]}\r\n"
-                body += f"--{boundary}--\r\n"
+                body_string += f"--{boundary}\r\n"
+                body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n'
+                body_string += f"{v[1]}\r\n"
+                body_string += f"--{boundary}--\r\n"
         elif self.node_data.body:
             if self.content:
                 if isinstance(self.content, str):
-                    body = self.content
+                    body_string = self.content
                 elif isinstance(self.content, bytes):
-                    body = self.content.decode("utf-8", errors="replace")
+                    body_string = self.content.decode("utf-8", errors="replace")
             elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
-                body = urlencode(self.data)
+                body_string = urlencode(self.data)
             elif self.data and self.node_data.body.type == "form-data":
                 for key, value in self.data.items():
-                    body += f"--{boundary}\r\n"
-                    body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
-                    body += f"{value}\r\n"
-                body += f"--{boundary}--\r\n"
+                    body_string += f"--{boundary}\r\n"
+                    body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
+                    body_string += f"{value}\r\n"
+                body_string += f"--{boundary}--\r\n"
             elif self.json:
-                body = json.dumps(self.json)
+                body_string = json.dumps(self.json)
             elif self.node_data.body.type == "raw-text":
-                body = self.node_data.body.data[0].value
-        if body:
-            raw += f"Content-Length: {len(body)}\r\n"
+                if len(self.node_data.body.data) != 1:
+                    raise RequestBodyError("raw-text body type should have exactly one item")
+                body_string = self.node_data.body.data[0].value
+        if body_string:
+            raw += f"Content-Length: {len(body_string)}\r\n"
         raw += "\r\n"  # Empty line between headers and body
-        raw += body
+        raw += body_string
 
         return raw
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index d040cc9f55..861119f26c 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -1,6 +1,7 @@
 import logging
+import mimetypes
 from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import Any, Optional
 
 from configs import dify_config
 from core.file import File, FileTransferMethod
@@ -19,7 +20,7 @@ from .entities import (
     HttpRequestNodeTimeout,
     Response,
 )
-from .exc import HttpRequestNodeError
+from .exc import HttpRequestNodeError, RequestBodyError
 
 HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
     connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@@ -35,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
     _node_type = NodeType.HTTP_REQUEST
 
     @classmethod
-    def get_default_config(cls, filters: dict | None = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
         return {
             "type": "http-request",
             "config": {
@@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                     "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
                 },
             },
+            "retry_config": {
+                "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
+                "retry_interval": 0.5 *
(2**2), + "retry_enabled": True, + }, } def _run(self) -> NodeRunResult: @@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): node_data=self.node_data, timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, + max_retries=0, ) process_data["request"] = http_executor.to_log() response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and self.should_continue_on_error: + if not response.response.is_success and (self.should_continue_on_error or self.should_retry): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs={ @@ -129,9 +136,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): data = node_data.body.data match body_type: case "binary": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selector = data[0].file selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) case "json" | "raw-text": + if len(data) != 1: + raise RequestBodyError("invalid body data, should have only one item") selectors += variable_template_parser.extract_selectors_from_template(data[0].key) selectors += variable_template_parser.extract_selectors_from_template(data[0].value) case "x-www-form-urlencoded": @@ -149,27 +160,31 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): ) mapping = {} - for selector in selectors: - mapping[node_id + "." + selector.variable] = selector.value_selector + for selector_iter in selectors: + mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector return mapping def extract_files(self, url: str, response: Response) -> list[File]: """ - Extract files from response + Extract files from response by checking both Content-Type header and URL """ files = [] is_file = response.is_file content_type = response.content_type content = response.content - if is_file and content_type: + if is_file: + # Guess file extension from URL or Content-Type header + filename = url.split("?")[0].split("/")[-1] or "" + mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, file_binary=content, - mimetype=content_type, + mimetype=mime_type, ) mapping = { diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 1991bf6c0e..a7d0aefc6d 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from flask import Flask, current_app from configs import dify_config -from core.variables import IntegerVariable +from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.workflow.entities.node_entities import ( NodeRunMetadataKey, NodeRunResult, @@ -76,12 +76,15 @@ class IterationNode(BaseNode[IterationNodeData]): """ Run the node. 
""" - iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - if not iterator_list_segment: - raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") + if not variable: + raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - if len(iterator_list_segment.value) == 0: + if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): + raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") + + if isinstance(variable, NoneVariable) or len(variable.value) == 0: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -90,7 +93,7 @@ class IterationNode(BaseNode[IterationNodeData]): ) return - iterator_list_value = iterator_list_segment.to_object() + iterator_list_value = variable.to_object() if not isinstance(iterator_list_value, list): raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") @@ -360,13 +363,16 @@ class IterationNode(BaseNode[IterationNodeData]): metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - if self.node_data.is_parallel: - metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id - else: - metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + metadata = { + **metadata, + NodeRunMetadataKey.ITERATION_ID: self.node_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID + if self.node_data.is_parallel + else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id + if self.node_data.is_parallel + else iter_run_index, + } event.route_node_state.node_run_result.metadata = metadata return event diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 8c5a9b5ecb..bfd93c074d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): except KnowledgeRetrievalNodeError as e: logger.warning("Error when running knowledge retrieval node") - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + # Temporary handle all exceptions from DatasetRetrieval class here. 
+ except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -134,6 +147,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config is None: + raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_model: reranking_model = { @@ -144,6 +159,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { @@ -160,18 +177,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): reranking_model = None weights = None all_documents = dataset_retrieval.multiple_retrieve( - self.app_id, - self.tenant_id, - self.user_id, - self.user_from.value, - available_datasets, - query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, + app_id=self.app_id, + tenant_id=self.tenant_id, + user_id=self.user_id, + user_from=self.user_from.value, + available_datasets=available_datasets, + query=query, + top_k=node_data.multiple_retrieval_config.top_k, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, + reranking_mode=node_data.multiple_retrieval_config.reranking_mode, + reranking_model=reranking_model, + weights=weights, + reranking_enable=node_data.multiple_retrieval_config.reranking_enable, ) dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] @@ -192,7 +211,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} + document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: document_score_list = {} @@ -247,7 +266,9 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( - retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True + retrieval_resource_list, + key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + reverse=True, ) position = 1 for item in retrieval_resource_list: @@ -282,6 +303,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): :param node_data: node data :return: """ + if node_data.single_retrieval_config is None: + raise ValueError("single_retrieval_config is required") model_name = node_data.single_retrieval_config.model.name provider_name = 
node_data.single_retrieval_config.model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 79066cece4..432c57294e 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal, Union +from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_type = NodeType.LIST_OPERATOR def _run(self): - inputs = {} - process_data = {} - outputs = {} + inputs: dict[str, list] = {} + process_data: dict[str, list] = {} + outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) if variable is None: @@ -93,6 +93,8 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): def _apply_filter( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + filter_func: Callable[[Any], bool] + result: list[Any] = [] for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): @@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str raise InvalidKeyError(f"Invalid key: {key}") -def _contains(value: str): +def _contains(value: str) -> Callable[[str], bool]: return lambda x: value in x -def _startswith(value: str): +def _startswith(value: str) -> Callable[[str], bool]: return lambda x: x.startswith(value) -def _endswith(value: str): +def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str): +def _is(value: str) -> Callable[[str], bool]: return lambda x: x is value -def _in(value: str | Sequence[str]): +def _in(value: str | Sequence[str]) -> Callable[[str], bool]: return lambda x: x in value -def _eq(value: int | float): +def _eq(value: int | float) -> Callable[[int | float], bool]: return lambda x: x == value -def _ne(value: int | float): +def _ne(value: int | float) -> Callable[[int | float], bool]: return lambda x: x != value -def _lt(value: int | float): +def _lt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x < value -def _le(value: int | float): +def _le(value: int | float) -> Callable[[int | float], bool]: return lambda x: x <= value -def _gt(value: int | float): +def _gt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x > value -def _ge(value: int | float): +def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value @@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + extract_func: Callable[[File], Any] if order_by in {"name", "type", 
"extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 19a66087f7..505068104c 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -50,6 +50,7 @@ class PromptConfig(BaseModel): class LLMNodeChatModelMessage(ChatModelMessage): + text: str = "" jinja2_text: Optional[str] = None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 67e62cb875..6909b30c9e 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: - node_inputs = None + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + node_inputs: Optional[dict[str, Any]] = None process_data = None try: @@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]): query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( - user_query=query, - user_files=files, + sys_query=query, + sys_files=files, context=context, memory=memory, model_config=model_config, @@ -196,7 +196,6 @@ class LLMNode(BaseNode[LLMNodeData]): error_type=type(e).__name__, ) ) - return except Exception as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -206,7 +205,6 @@ class LLMNode(BaseNode[LLMNodeData]): process_data=process_data, ) ) - return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -302,7 +300,7 @@ class LLMNode(BaseNode[LLMNodeData]): return messages def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables = {} + variables: dict[str, Any] = {} if not node_data.prompt_config: return variables @@ -319,7 +317,7 @@ class LLMNode(BaseNode[LLMNodeData]): """ # check if it's a context structure if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return input_dict["content"] + return str(input_dict["content"]) # else, parse the dict try: @@ -545,8 +543,8 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_prompt_messages( self, *, - user_query: str | None = None, - user_files: Sequence["File"], + sys_query: str | None = None, + sys_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -557,12 +555,13 @@ class LLMNode(BaseNode[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: - prompt_messages = [] + # FIXME: fix the type error cause prompt_messages is type quick a few times + prompt_messages: list[Any] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - _handle_list_messages( + self._handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -581,14 +580,14 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages.extend(memory_messages) # Add current query to the prompt messages - if user_query: + if sys_query: message = LLMNodeChatModelMessage( - text=user_query, + text=sys_query, role=PromptMessageRole.USER, edition_type="basic", ) prompt_messages.extend( - _handle_list_messages( + 
self._handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -635,24 +634,27 @@ class LLMNode(BaseNode[LLMNodeData]): raise ValueError("Invalid prompt content type") # Add current query to the prompt message - if user_query: + if sys_query: if prompt_content_type == str: - prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content elif prompt_content_type == list: for content_item in prompt_content: if content_item.type == PromptMessageContentType.TEXT: - content_item.data = user_query + "\n" + content_item.data + content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") else: raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - if vision_enabled and user_files: + # The sys_files will be deprecated later + if vision_enabled and sys_files: file_prompts = [] - for file in user_files: + for file in sys_files: file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt if ( len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage) @@ -662,7 +664,7 @@ class LLMNode(BaseNode[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) - # Filter prompt messages + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message.content, list): @@ -780,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]): else: raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - variable_mapping = {} + variable_mapping: dict[str, Any] = {} for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector @@ -846,6 +848,68 @@ class LLMNode(BaseNode[LLMNodeData]): }, } + def _handle_list_messages( + self, + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, + ) -> Sequence[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, 
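The new comment above ("If last prompt is a user prompt, add files into its contents, otherwise append a new user prompt") is worth pinning down, since the branch is easy to misread. A sketch with plain dicts standing in for `UserPromptMessage`:

```python
def attach_files(prompt_messages: list[dict], file_prompts: list[str]) -> None:
    last = prompt_messages[-1] if prompt_messages else None
    if last is not None and last["role"] == "user" and isinstance(last["content"], list):
        # Merge the file contents into the trailing user message.
        last["content"].extend(file_prompts)
    else:
        # No suitable trailing message: append a fresh user message.
        prompt_messages.append({"role": "user", "content": list(file_prompts)})


messages = [{"role": "user", "content": ["describe this image"]}]
attach_files(messages, ["<image:photo.png>"])
print(messages)  # file content merged into the existing user message
```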
FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -880,68 +944,6 @@ def _render_jinja2_message( return result_text -def _handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, -) -> Sequence[PromptMessage]: - prompt_messages = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinjia2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - if isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages - - def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity ) -> int: @@ -978,7 +980,7 @@ def _handle_memory_chat_mode( memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: - memory_messages = [] + memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) diff --git 
a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 6fdff96602..a366c287c2 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self) -> LoopState: - return super()._run() + def _run(self) -> LoopState: # type: ignore + return super()._run() # type: ignore @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: @@ -28,7 +28,7 @@ class LoopNode(BaseNode[LoopNodeData]): # TODO waiting for implementation return [ - Condition( + Condition( # type: ignore variable_selector=[node_id, "index"], comparison_operator="≤", value_type="value_selector", diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index a001b44dc7..369eb13b04 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -25,7 +25,7 @@ class ParameterConfig(BaseModel): raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return value + return str(value) class ParameterExtractorNodeData(BaseNodeData): @@ -52,7 +52,7 @@ class ParameterExtractorNodeData(BaseNodeData): :return: parameter json schema """ - parameters = {"type": "object", "properties": {}, "required": []} + parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 5884af0412..e147caacf3 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode): Parameter Extractor Node. 
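The `parameters: dict[str, Any]` annotation in `get_parameter_json_schema` is the standard fix for a common mypy complaint: from the literal alone it infers `dict[str, object]`, and `object` can be neither indexed nor appended to. A standalone sketch of the schema assembly (parameter specs are made up):

```python
from typing import Any

# Without the explicit annotation mypy infers dict[str, object] from the
# literal and then rejects parameters["required"].append(...).
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}

specs = [("city", "City name", True), ("unit", "C or F", False)]
for name, description, required in specs:
    parameters["properties"][name] = {"type": "string", "description": description}
    if required:
        parameters["required"].append(name)

print(parameters)
```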
""" - _node_data_cls = ParameterExtractorNodeData + # FIXME: figure out why here is different from super class + _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR _model_instance: Optional[ModelInstance] = None @@ -179,6 +180,15 @@ class ParameterExtractorNode(LLMNode): error=str(e), metadata={}, ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, + error=str(e), + metadata={}, + ) error = None @@ -244,6 +254,9 @@ class ParameterExtractorNode(LLMNode): # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + if text is None: + text = "" + return text, usage, tool_call def _generate_function_call_prompt( @@ -596,9 +609,10 @@ class ParameterExtractorNode(LLMNode): json_str = extract_json(result[idx:]) if json_str: try: - return json.loads(json_str) + return cast(dict, json.loads(json_str)) except Exception: pass + return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: """ @@ -607,13 +621,13 @@ class ParameterExtractorNode(LLMNode): if not tool_call or not tool_call.function.arguments: return None - return json.loads(tool_call.function.arguments) + return cast(dict, json.loads(tool_call.function.arguments)) def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ Generate default result. """ - result = {} + result: dict[str, Any] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0 @@ -763,7 +777,7 @@ class ParameterExtractorNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, + node_data: ParameterExtractorNodeData, # type: ignore ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -772,6 +786,7 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ + # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index e603add170..6c3155ac9a 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,3 +1,5 @@ +from typing import Any + FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. 
@@ -35,7 +37,7 @@ FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information fr """ # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ +FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ { "user": { "query": "What is the weather today in SF?", diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 7594036b50..0ec44eefac 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,10 +1,8 @@ import json -import logging from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole @@ -36,12 +34,9 @@ from .template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_3, ) -if TYPE_CHECKING: - from core.file import File - class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData + _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER def _run(self): @@ -63,7 +58,7 @@ class QuestionClassifierNode(LLMNode): node_data.instruction = node_data.instruction or "" node_data.instruction = variable_pool.convert_template(node_data.instruction).text - files: Sequence[File] = ( + files = ( self._fetch_files( selector=node_data.vision.configs.variable_selector, ) @@ -86,37 +81,38 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - user_query=query, + sys_query=query, memory=memory, model_config=model_config, - user_files=files, + sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], ) - # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - ) - result_text = "" usage = LLMUsage.empty_usage() finish_reason = None - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - category_name = node_data.classes[0].name - category_id = node_data.classes[0].id try: + # handle invoke result + generator = self._invoke_llm( + node_data_model=node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + ) + + for event in generator: + if isinstance(event, ModelInvokeCompletedEvent): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + + category_name = node_data.classes[0].name + category_id = node_data.classes[0].id result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) if "category_name" in result_text_json and "category_id" in result_text_json: @@ -127,10 +123,6 @@ class QuestionClassifierNode(LLMNode): if category_id_result in category_ids: category_name = classes_map[category_id_result] category_id = category_id_result - - 
except OutputParserError: - logging.exception(f"Failed to parse result text: {result_text}") - try: process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -154,7 +146,6 @@ class QuestionClassifierNode(LLMNode): }, llm_usage=usage, ) - except ValueError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -174,7 +165,7 @@ class QuestionClassifierNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: QuestionClassifierNodeData, + node_data: Any, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -183,6 +174,7 @@ class QuestionClassifierNode(LLMNode): :param node_data: node data :return: """ + node_data = cast(QuestionClassifierNodeData, node_data) variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index d8752058b0..dc919892e5 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType from core.plugin.manager.exc import PluginDaemonClientSideError from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine -from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable @@ -58,6 +57,8 @@ class ToolNode(BaseNode[ToolNodeData]): # get tool runtime try: + from core.tools.tool_manager import ToolManager + tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) @@ -145,7 +146,7 @@ class ToolNode(BaseNode[ToolNodeData]): """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - result = {} + result: dict[str, Any] = {} for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py index a1178fb020..f8dbedc290 100644 --- a/api/core/workflow/nodes/variable_assigner/common/exc.py +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -1,4 +1,4 @@ -class VariableOperatorNodeError(Exception): +class VariableOperatorNodeError(ValueError): """Base error type, don't use directly.""" pass diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 8eb4bd5c2d..9acc76f326 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -36,6 +36,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) + if income_value is None: + raise VariableOperatorNodeError("income value not found") updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d73c744202..0c4aae827c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ 
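The quiet one-liner above, rebasing `VariableOperatorNodeError` from `Exception` onto `ValueError`, widens which handlers catch it: any existing `except ValueError` block now also traps variable-operator failures. A tiny demonstration:

```python
class VariableOperatorNodeError(ValueError):
    """Base error type, don't use directly."""


try:
    raise VariableOperatorNodeError("income value not found")
except ValueError as e:
    # Caught by a generic ValueError handler, which an Exception-based
    # subclass would have bypassed.
    print(f"caught: {e}")
```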
b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, cast from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() - process_data = {} + process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variables: list[Variable] = [] @@ -119,7 +119,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): else: conversation_id = conversation_id.value common_helpers.update_conversation_variable( - conversation_id=conversation_id, + conversation_id=cast(str, conversation_id), variable=variable, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 0800c48e2a..06e4b35dec 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -129,11 +129,11 @@ class WorkflowEntry: :return: """ # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: + workflow_graph = workflow.graph_dict + if not workflow_graph: raise ValueError("workflow graph not found") - nodes = graph.get("nodes") + nodes = workflow_graph.get("nodes") if not nodes: raise ValueError("nodes not found in workflow graph") @@ -297,7 +297,8 @@ class WorkflowEntry: @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: - return WorkflowEntry._handle_special_values(value) + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod def _handle_special_values(value: Any) -> Any: @@ -309,10 +310,10 @@ class WorkflowEntry: res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res = [] + res_list = [] for item in value: - res.append(WorkflowEntry._handle_special_values(item)) - return res + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list if isinstance(value, File): return value.to_dict() return value diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 24fa013697..8a677f6b6f 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ from models.dataset import Document @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids") + document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() for document_id in document_ids: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 1515661b2d..5e7caf8cbe 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -8,18 +8,19 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender account = kwargs.get("account") - site = Site( - app_id=app.id, - title=app.name, - icon_type=app.icon_type, - icon=app.icon, - icon_background=app.icon_background, - default_language=account.interface_language, - customize_token_strategy="not_allow", - code=Site.generate_code(16), - created_by=app.created_by, - updated_by=app.updated_by, - ) + if account 
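The `res`/`res_list` rename in `_handle_special_values` is another mypy-driven change: reusing one local for both a dict and a list makes the checker reject the second assignment. A runnable sketch of the recursive normalization with the split locals (`File.to_dict()` handling elided):

```python
from typing import Any


def handle_special_values(value: Any) -> Any:
    if isinstance(value, dict):
        res: dict[str, Any] = {}
        for k, v in value.items():
            res[k] = handle_special_values(v)
        return res
    if isinstance(value, list):
        res_list: list[Any] = []  # separate name, separate type
        for item in value:
            res_list.append(handle_special_values(item))
        return res_list
    return value  # File instances would be converted to dicts here


print(handle_special_values({"a": [1, {"b": 2}]}))
```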
is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) - db.session.add(site) - db.session.commit() + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index f1479c58a4..70d2c7d8c9 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -47,7 +47,7 @@ def handle(sender, **kwargs): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 9c5955c8c5..f89fae24a5 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,7 +8,10 @@ from events.app_event import app_draft_workflow_was_synced @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index de7c0f4dfe..408ed31096 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -8,16 +8,18 @@ from models.model import AppModelConfig def handle(sender, **kwargs): app = sender app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -37,8 +39,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: - dataset_ids = set() +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]: + dataset_ids: set[int] = set() if not app_model_config: return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py 
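All three event-handler fixes above share one shape: values pulled from a Blinker signal's `**kwargs` are `Optional`, so the handler either returns early or wraps its work in a `None` guard. A runnable sketch using the blinker library these handlers are built on (the signal and payload names are made up):

```python
from blinker import signal

app_event = signal("app-event")  # hypothetical signal


@app_event.connect
def handle(sender, **kwargs):
    payload = kwargs.get("payload")
    if payload is None:
        return  # guard: kwargs.get() may legitimately return None
    print(f"{sender}: {payload['message']}")


app_event.send("app-1")                                  # no payload: ignored
app_event.send("app-1", payload={"message": "created"})  # handled
```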
b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 453395e8d7..7a31c82f6a 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -17,11 +17,11 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_workflow(published_workflow) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[int] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[int] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -41,8 +41,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: - dataset_ids = set() +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]: + dataset_ids: set[int] = set() graph = published_workflow.graph_dict if not graph: return dataset_ids @@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: for node in knowledge_retrieval_nodes: try: node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) - dataset_ids.update(node_data.dataset_ids) + dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids) except Exception as e: continue diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/__init__.py b/api/extensions/__init__.py similarity index 100% rename from api/tests/integration_tests/model_runtime/gitee_ai/__init__.py rename to api/extensions/__init__.py diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index de1cdfeb98..b7d412d68d 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -54,12 +54,14 @@ def init_app(app: DifyApp): from extensions.ext_database import db engine = db.engine + # TODO: Fix the type error + # FIXME maybe its sqlalchemy issue return { "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore } diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 9dbc4b93d4..30f216ff95 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,8 @@ from datetime import timedelta import pytz -from celery import Celery, Task -from celery.schedules import crontab +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore from configs import dify_config from dify_app import DifyApp @@ -47,7 +47,7 @@ def init_app(app: DifyApp) -> Celery: worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, - timezone=pytz.timezone(dify_config.LOG_TZ), 
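The dataset-join handlers now type their bookkeeping as `set[int]` end to end, so `removed_dataset_ids` starts life as a set (not a list) and the set algebra type-checks; ids collected from workflow nodes are coerced with `int(...)` to match. The core bookkeeping, reduced to a sketch:

```python
new_dataset_ids: set[int] = {1, 2, 4}   # ids referenced by the updated app
old_dataset_ids: set[int] = {2, 3}      # ids from existing AppDatasetJoin rows

added_dataset_ids = new_dataset_ids - old_dataset_ids
removed_dataset_ids: set[int] = old_dataset_ids - new_dataset_ids

print(sorted(added_dataset_ids), sorted(removed_dataset_ids))  # [1, 4] [3]
```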
+ timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) if dify_config.BROKER_USE_SSL: diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 9c3a663af4..26ff6427be 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress + from flask_compress import Compress # type: ignore compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index e293afa111..93842a3036 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -1,18 +1,5 @@ -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import MetaData - from dify_app import DifyApp - -POSTGRES_INDEXES_NAMING_CONVENTION = { - "ix": "%(column_0_label)s_idx", - "uq": "%(table_name)s_%(column_0_name)s_key", - "ck": "%(table_name)s_%(constraint_name)s_check", - "fk": "%(table_name)s_%(column_0_name)s_fkey", - "pk": "%(table_name)s_pkey", -} - -metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) -db = SQLAlchemy(metadata=metadata) +from models import db def init_app(app: DifyApp): diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py index eefdfd3823..9566f430b6 100644 --- a/api/extensions/ext_import_modules.py +++ b/api/extensions/ext_import_modules.py @@ -3,4 +3,3 @@ from dify_app import DifyApp def init_app(app: DifyApp): from events import event_handlers # noqa: F401 - from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 9fc29b4eb1..e1c459e8c1 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -11,7 +11,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): - log_handlers = [] + log_handlers: list[logging.Handler] = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -49,7 +49,8 @@ def init_app(app: DifyApp): return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() for handler in logging.root.handlers: - handler.formatter.converter = time_converter + if handler.formatter: + handler.formatter.converter = time_converter def get_request_id(): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index b295530714..10fb89eb73 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login +import flask_login # type: ignore from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 468aedd47e..9240ebe7fc 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ class Mail: match mail_type: case "resend": - import resend + import resend # type: ignore api_key = dify_config.RESEND_API_KEY if not api_key: @@ -48,9 +48,9 @@ class Mail: self._client = SMTPClient( server=dify_config.SMTP_SERVER, port=dify_config.SMTP_PORT, - username=dify_config.SMTP_USERNAME, - password=dify_config.SMTP_PASSWORD, - _from=dify_config.MAIL_DEFAULT_SEND_FROM, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) diff --git 
a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 6d8f35c30d..5f862181fa 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): - import flask_migrate + import flask_migrate # type: ignore from extensions.ext_database import db diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index 3b895ac95b..514e065825 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app) + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 8016356a3e..3a74aace6a 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,7 +6,7 @@ def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import openai import sentry_sdk - from langfuse import parse_error + from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException @@ -27,6 +27,7 @@ def init_app(app: DifyApp): ignore_errors=[ HTTPException, ValueError, + FileNotFoundError, openai.APIStatusError, InvokeRateLimitError, parse_error.defaultErrorResponse, diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 4b66f3801e..588bdb2d27 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,11 +1,10 @@ import logging -from collections.abc import Callable, Generator, Mapping -from typing import Union +from collections.abc import Callable, Generator +from typing import Literal, Union, overload from flask import Flask from configs import dify_config -from configs.middleware.storage.opendal_storage_config import OpenDALScheme from dify_app import DifyApp from extensions.storage.base_storage import BaseStorage from extensions.storage.storage_type import StorageType @@ -23,21 +22,17 @@ class Storage: def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]: match storage_type: case StorageType.S3: - from extensions.storage.opendal_storage import OpenDALStorage + from extensions.storage.aws_s3_storage import AwsS3Storage - kwargs = _load_s3_storage_kwargs() - return lambda: OpenDALStorage(scheme=OpenDALScheme.S3, **kwargs) + return AwsS3Storage case StorageType.OPENDAL: from extensions.storage.opendal_storage import OpenDALStorage - scheme = OpenDALScheme(dify_config.STORAGE_OPENDAL_SCHEME) - kwargs = _load_opendal_storage_kwargs(scheme) - return lambda: OpenDALStorage(scheme=scheme, **kwargs) + return lambda: OpenDALStorage(dify_config.OPENDAL_SCHEME) case StorageType.LOCAL: from extensions.storage.opendal_storage import OpenDALStorage - kwargs = _load_local_storage_kwargs() - return lambda: OpenDALStorage(scheme=OpenDALScheme.FS, **kwargs) + return lambda: OpenDALStorage(scheme="fs", root=dify_config.STORAGE_LOCAL_PATH) case StorageType.AZURE_BLOB: from extensions.storage.azure_blob_storage import AzureBlobStorage @@ -75,7 +70,7 @@ class Storage: return SupabaseStorage case _: - raise ValueError(f"Unsupported storage type {storage_type}") + raise ValueError(f"unsupported storage type {storage_type}") def save(self, filename, data): try: @@ -84,6 +79,12 @@ class Storage: logger.exception(f"Failed to save file 
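After the refactor, `get_storage_factory` always returns a zero-argument callable, which makes a bare class object and a configured lambda interchangeable. A minimal sketch of that shape with stand-in classes (not the real storage backends):

```python
from collections.abc import Callable


class BaseStorage: ...


class AwsS3Storage(BaseStorage): ...


class OpenDALStorage(BaseStorage):
    def __init__(self, scheme: str, **kwargs):
        self.scheme = scheme


def get_storage_factory(storage_type: str) -> Callable[[], BaseStorage]:
    match storage_type:
        case "s3":
            return AwsS3Storage  # a class is already a zero-arg factory
        case "opendal":
            return lambda: OpenDALStorage("fs")
        case _:
            raise ValueError(f"unsupported storage type {storage_type}")


storage = get_storage_factory("opendal")()
print(type(storage).__name__, storage.scheme)  # OpenDALStorage fs
```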
{filename}") raise e + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: @@ -130,81 +131,6 @@ class Storage: raise e -def _load_s3_storage_kwargs() -> Mapping[str, str]: - """ - Load the kwargs for S3 storage based on dify_config. - Handles special cases like AWS managed IAM and R2. - """ - kwargs = { - "root": "/", - "bucket": dify_config.S3_BUCKET_NAME, - "endpoint": dify_config.S3_ENDPOINT, - "access_key_id": dify_config.S3_ACCESS_KEY, - "secret_access_key": dify_config.S3_SECRET_KEY, - "region": dify_config.S3_REGION, - } - kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)} - - # For AWS managed IAM - if dify_config.S3_USE_AWS_MANAGED_IAM: - from extensions.storage.opendal_storage import S3_SSE_WITH_AWS_MANAGED_IAM_KWARGS - - logger.debug("Using AWS managed IAM role for S3") - kwargs = {**kwargs, **{k: v for k, v in S3_SSE_WITH_AWS_MANAGED_IAM_KWARGS.items() if k not in kwargs}} - - # For Cloudflare R2 - if kwargs.get("endpoint"): - from extensions.storage.opendal_storage import S3_R2_COMPATIBLE_KWARGS, is_r2_endpoint - - if is_r2_endpoint(kwargs["endpoint"]): - logger.debug("Using R2 for OpenDAL S3") - kwargs = {**kwargs, **{k: v for k, v in S3_R2_COMPATIBLE_KWARGS.items() if k not in kwargs}} - - return kwargs - - -def _load_local_storage_kwargs() -> Mapping[str, str]: - """ - Load the kwargs for local storage based on dify_config. - """ - return { - "root": dify_config.STORAGE_LOCAL_PATH, - } - - -def _load_opendal_storage_kwargs(scheme: OpenDALScheme) -> Mapping[str, str]: - """ - Load the kwargs for OpenDAL storage based on the given scheme. 
- """ - match scheme: - case OpenDALScheme.FS: - kwargs = { - "root": dify_config.OPENDAL_FS_ROOT, - } - case OpenDALScheme.S3: - # Load OpenDAL S3-related configs - kwargs = { - "root": dify_config.OPENDAL_S3_ROOT, - "bucket": dify_config.OPENDAL_S3_BUCKET, - "endpoint": dify_config.OPENDAL_S3_ENDPOINT, - "access_key_id": dify_config.OPENDAL_S3_ACCESS_KEY_ID, - "secret_access_key": dify_config.OPENDAL_S3_SECRET_ACCESS_KEY, - "region": dify_config.OPENDAL_S3_REGION, - } - - # For Cloudflare R2 - if kwargs.get("endpoint"): - from extensions.storage.opendal_storage import S3_R2_COMPATIBLE_KWARGS, is_r2_endpoint - - if is_r2_endpoint(kwargs["endpoint"]): - logger.debug("Using R2 for OpenDAL S3") - kwargs = {**kwargs, **{k: v for k, v in S3_R2_COMPATIBLE_KWARGS.items() if k not in kwargs}} - case _: - logger.warning(f"Unrecognized OpenDAL scheme: {scheme}, will fall back to default.") - kwargs = {} - return kwargs - - storage = Storage() diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 58c917dbd3..00bf5d4f93 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 +import oss2 as aliyun_s3 # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -33,7 +33,7 @@ class AliyunOssStorage(BaseStorage): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data = obj.read() + data: bytes = obj.read() return data def load_stream(self, filename: str) -> Generator: @@ -41,14 +41,14 @@ class AliyunOssStorage(BaseStorage): while chunk := obj.read(4096): yield chunk - def download(self, filename, target_filepath): + def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) - def exists(self, filename): + def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) - def delete(self, filename): + def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) - def __wrapper_folder_filename(self, filename) -> str: + def __wrapper_folder_filename(self, filename: str) -> str: return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ab2d0fba3b..7b6b2eedd6 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,9 +1,9 @@ import logging from collections.abc import Generator -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -53,7 +53,7 @@ class AwsS3Storage(BaseStorage): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -67,7 +67,9 @@ class AwsS3Storage(BaseStorage): yield from 
response["Body"].iter_chunks() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") + raise FileNotFoundError("file not found") + elif "reached max retries" in str(ex): + raise ValueError("please do not request the same file too frequently") else: raise diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index b26caa8671..2f8532f4f8 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,7 +27,7 @@ class AzureBlobStorage(BaseStorage): client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() + data: bytes = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: @@ -63,11 +63,11 @@ class AzureBlobStorage(BaseStorage): sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, + account_name=self.account_name or "", + account_key=self.account_key or "", resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index e0d2140e91..b94efa08be 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import base64 import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials -from baidubce.bce_client_configuration import BceClientConfiguration -from baidubce.services.bos.bos_client import BosClient +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -36,7 +36,8 @@ class BaiduObsStorage(BaseStorage): def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) - return response.data.read() + data: bytes = response.data.read() + return data def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 26b662d2f0..705639f42e 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import io import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage +from google.cloud import storage as google_cloud_storage # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ class GoogleCloudStorage(BaseStorage): def load_once(self, filename: str) -> bytes: bucket = 
self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - data = blob.download_as_bytes() + data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 20be70ef83..07f1d19970 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient +from obs import ObsClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -23,7 +23,7 @@ class HuaweiObsStorage(BaseStorage): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) def load_once(self, filename: str) -> bytes: - data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index dc71839c70..b78fc94dae 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -1,46 +1,56 @@ +import logging +import os from collections.abc import Generator from pathlib import Path -from urllib.parse import urlparse -import opendal +import opendal # type: ignore[import] +from dotenv import dotenv_values -from configs.middleware.storage.opendal_storage_config import OpenDALScheme from extensions.storage.base_storage import BaseStorage -S3_R2_HOSTNAME = "r2.cloudflarestorage.com" -S3_R2_COMPATIBLE_KWARGS = { - "delete_max_size": "700", - "disable_stat_with_override": "true", - "region": "auto", -} -S3_SSE_WITH_AWS_MANAGED_IAM_KWARGS = { - "server_side_encryption": "aws:kms", -} +logger = logging.getLogger(__name__) -def is_r2_endpoint(endpoint: str) -> bool: - if not endpoint: - return False +def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str = "OPENDAL_"): + kwargs = {} + config_prefix = prefix + scheme.upper() + "_" + for key, value in os.environ.items(): + if key.startswith(config_prefix): + kwargs[key[len(config_prefix) :].lower()] = value - parsed_url = urlparse(endpoint) - return bool(parsed_url.hostname and parsed_url.hostname.endswith(S3_R2_HOSTNAME)) + file_env_vars: dict = dotenv_values(env_file_path) or {} + for key, value in file_env_vars.items(): + if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: + kwargs[key[len(config_prefix) :].lower()] = value + + return kwargs class OpenDALStorage(BaseStorage): - def __init__(self, scheme: OpenDALScheme, **kwargs): - if scheme == OpenDALScheme.FS: - Path(kwargs["root"]).mkdir(parents=True, exist_ok=True) + def __init__(self, scheme: str, **kwargs): + kwargs = kwargs or _get_opendal_kwargs(scheme=scheme) + + if scheme == "fs": + root = kwargs.get("root", "storage") + Path(root).mkdir(parents=True, exist_ok=True) self.op = opendal.Operator(scheme=scheme, **kwargs) + logger.debug(f"opendal operator created with scheme {scheme}") + retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) + self.op = self.op.layer(retry_layer) + logger.debug("added retry layer to opendal operator") def save(self, filename: str, data: bytes) -> None: self.op.write(path=filename, bs=data) + 
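`_get_opendal_kwargs` encodes a naming convention: any `OPENDAL_<SCHEME>_<KEY>` environment variable becomes a lowercase `<key>` kwarg for the OpenDAL operator, with the process environment taking precedence over `.env`. A simplified re-implementation (omitting the dotenv fallback) showing the mapping:

```python
import os


def get_opendal_kwargs(scheme: str, prefix: str = "OPENDAL_") -> dict[str, str]:
    config_prefix = prefix + scheme.upper() + "_"  # e.g. "OPENDAL_FS_"
    return {
        key[len(config_prefix):].lower(): value
        for key, value in os.environ.items()
        if key.startswith(config_prefix)
    }


os.environ["OPENDAL_FS_ROOT"] = "storage"
print(get_opendal_kwargs("fs"))  # {'root': 'storage'} -> Operator("fs", root="storage")
```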
logger.debug(f"file {filename} saved") def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") - return self.op.read(path=filename) + content: bytes = self.op.read(path=filename) + logger.debug(f"file {filename} loaded") + return content def load_stream(self, filename: str) -> Generator: if not self.exists(filename): @@ -50,6 +60,7 @@ class OpenDALStorage(BaseStorage): file = self.op.open(path=filename, mode="rb") while chunk := file.read(batch_size): yield chunk + logger.debug(f"file {filename} loaded as stream") def download(self, filename: str, target_filepath: str): if not self.exists(filename): @@ -57,16 +68,22 @@ class OpenDALStorage(BaseStorage): with Path(target_filepath).open("wb") as f: f.write(self.op.read(path=filename)) + logger.debug(f"file {filename} downloaded to {target_filepath}") def exists(self, filename: str) -> bool: # FIXME this is a workaround for opendal python-binding do not have a exists method and no better # error handler here when opendal python-binding has a exists method, we should use it # more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs try: - return self.op.stat(path=filename).mode.is_file() - except Exception as e: + res: bool = self.op.stat(path=filename).mode.is_file() + logger.debug(f"file {filename} checked") + return res + except Exception: return False def delete(self, filename: str): if self.exists(filename): self.op.delete(path=filename) + logger.debug(f"file {filename} deleted") + return + logger.debug(f"file {filename} not found, skip delete") diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index b59f83b8de..82829f7fd5 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -27,7 +27,7 @@ class OracleOCIStorage(BaseStorage): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 9f7c69a9ae..711c3f7211 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -32,7 +32,7 @@ class SupabaseStorage(BaseStorage): self.client.storage.from_(self.bucket_name).upload(filename, data) def load_once(self, filename: str) -> bytes: - content = self.client.storage.from_(self.bucket_name).download(filename) + content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 13a6c9239c..9cdd3e67f7 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos import CosConfig, 
CosS3Client # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -25,7 +25,7 @@ class TencentCosStorage(BaseStorage): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index de82be04ea..55fe6545ec 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos +import tos # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -24,6 +24,8 @@ class VolcengineTosStorage(BaseStorage): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + if not isinstance(data, bytes): + raise TypeError("Expected bytes, got {}".format(type(data).__name__)) return data def load_stream(self, filename: str) -> Generator: diff --git a/api/tests/integration_tests/model_runtime/gpustack/__init__.py b/api/factories/__init__.py similarity index 100% rename from api/tests/integration_tests/model_runtime/gpustack/__init__.py rename to api/factories/__init__.py diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index dccae186f0..99c7195b2c 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -64,7 +64,7 @@ def build_from_mapping( if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file = build_func( + file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, @@ -72,7 +72,7 @@ def build_from_mapping( if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension, + file_extension=file.extension or "", file_transfer_method=file.transfer_method, config=config, ): @@ -116,8 +116,11 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") stmt = select(UploadFile).where( - UploadFile.id == mapping.get("upload_file_id"), + UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) @@ -139,6 +142,7 @@ def _build_from_local_file( remote_url=row.source_url, related_id=mapping.get("upload_file_id"), size=row.size, + storage_key=row.key, ) @@ -168,6 +172,7 @@ def _build_from_remote_url( mime_type=mime_type, extension=extension, size=file_size, + storage_key="", ) @@ -220,6 +225,7 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, + storage_key=tool_file.file_key, ) @@ -275,6 +281,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None: return FileType.AUDIO elif extension in DOCUMENT_EXTENSIONS: return FileType.DOCUMENT + return None def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 16a578728a..bbca8448ec 100644 --- a/api/factories/variable_factory.py +++ 
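`build_from_mapping` keeps the dispatch-table shape sketched below: transfer methods map to builder callables, an unknown method fails fast, and the new guard rejects a missing `upload_file_id` before any query runs. Stand-in builders returning strings instead of `File` objects:

```python
from collections.abc import Callable


def _build_from_local_file(mapping: dict) -> str:
    upload_file_id = mapping.get("upload_file_id")
    if not upload_file_id:  # the guard added in the diff
        raise ValueError("Invalid upload file id")
    return f"local:{upload_file_id}"


def _build_from_remote_url(mapping: dict) -> str:
    return f"remote:{mapping['url']}"


_BUILDERS: dict[str, Callable[[dict], str]] = {
    "local_file": _build_from_local_file,
    "remote_url": _build_from_remote_url,
}


def build_from_mapping(mapping: dict) -> str:
    build_func = _BUILDERS.get(mapping["transfer_method"])
    if not build_func:
        raise ValueError(f"Invalid file transfer method: {mapping['transfer_method']}")
    return build_func(mapping)


print(build_from_mapping({"transfer_method": "local_file", "upload_file_id": "f1"}))
```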
b/api/factories/variable_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, cast from uuid import uuid4 from configs import dify_config @@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return result + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: @@ -164,10 +166,13 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=selector, + return cast( + Variable, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ), ) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16..1c58b3a257 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34db..d40407bfcc 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index e6b49e873b..b14f8d0e73 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 5bd21be807..c54554a6de 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -85,7 +85,7 @@ message_detail_fields = { } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} - +status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer} model_config_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -166,6 +166,7 @@ conversation_with_summary_fields = { "message_count": fields.Integer, "user_feedback_stats": fields.Nested(feedback_stat_fields), "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "status_count": fields.Nested(status_count_fields), } conversation_with_summary_pagination_fields = { diff --git 
a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 983e50e73c..c6385efb5a 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376f..608672121e 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 533e3a0837..a74e6f54fb 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index a83ec7bc97..2b2ac6243f 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1..aefa0b2758 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore simple_end_user_fields = { "id": fields.String, diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py index 2281460fe2..9cc4e14a05 100644 --- a/api/fields/external_dataset_fields.py +++ b/api/fields/external_dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index afaacc0568..f896c15f0f 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index f36e80f8d4..aaafcab8ab 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f6..16f265b9bb 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 1cf8e408d1..0c854c640c 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5f6e7884a6..0571faab08 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py 
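The recurring "# type: ignore" edits in these fields modules exist because flask_restful ships no type stubs, so the mypy step added to CI earlier in this patch would flag every bare import. A minimal sketch of how the new status_count_fields map is consumed; the payload below is illustrative and not part of the patch:

from flask_restful import fields, marshal  # type: ignore

status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}

# marshal() filters and coerces the payload according to the field map
summary = {"status_count": {"success": 10, "failed": 1, "partial_success": 2}}
print(marshal(summary, {"status_count": fields.Nested(status_count_fields)}))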
@@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.conversation_fields import message_file_fields from libs.helper import TimestampField diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13..493d4b6cce 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2dd4cb45be..4413af3160 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd..986cd725f7 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,3 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a53b546249..c45b33597b 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0d860d6f40..bd093d4063 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 8390c66556..ef59c57ec3 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -29,6 +29,7 @@ workflow_run_for_list_fields = { "created_at": TimestampField, "finished_at": TimestampField, "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_for_list_fields = { @@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = { "created_at": TimestampField, "finished_at": TimestampField, "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_pagination_fields = { @@ -79,6 +81,19 @@ workflow_run_detail_fields = { "exceptions_count": fields.Integer, } +retry_event_field = { + "elapsed_time": fields.Float, + "status": fields.String, + "inputs": fields.Raw(attribute="inputs"), + "process_data": fields.Raw(attribute="process_data"), + "outputs": fields.Raw(attribute="outputs"), + "metadata": fields.Raw(attribute="metadata"), + "llm_usage": fields.Raw(attribute="llm_usage"), + "error": fields.String, + "retry_index": fields.Integer, +} + + workflow_run_node_execution_fields = { "id": fields.String, "index": fields.Integer, diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 179617ac0a..922d2d9cd3 100644 --- a/api/libs/external_api.py +++ 
b/api/libs/external_api.py @@ -1,8 +1,9 @@ import re import sys +from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message +from flask_restful import Api, http_status_message # type: ignore from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException @@ -84,7 +85,7 @@ class ExternalApi(Api): # record the exception in the logs when we have a server error of status code: 500 if status_code and status_code >= 500: - exc_info = sys.exc_info() + exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None current_app.log_exception(exc_info) @@ -100,7 +101,7 @@ class ExternalApi(Api): resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) elif status_code == 400: if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message").items())[0] + param_key, param_value = list(data.get("message", {}).items())[0] data = {"code": "invalid_param", "message": param_value, "params": param_key} else: if "code" not in data: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 83f9c74e33..2dae87e171 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ from hashlib import sha1 import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 +import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -191,12 +191,12 @@ class PKCS1OAepCipher: # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) + invalid = bord(y) | int(one_pos < hLen) # type: ignore hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) + invalid |= bord(x) # type: ignore for x in db[hLen:one_pos]: - invalid |= bord(x) + invalid |= bord(x) # type: ignore if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index 78f36bc58b..884d728492 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -248,13 +248,13 @@ class TokenManager: if token_data_json is None: logging.warning(f"{token_type} token {token} not found with key {key}") return None - token_data = json.loads(token_data_json) + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: key = cls._get_account_token_key(account_id, token_type) - current_token = redis_client.get(key) + current_token: Optional[str] = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41c5d20c4b..9ab53b6294 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict: ends = ["```", "``", "`", "}"] end_index = -1 start_index = 0 + parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: @@ -27,7 +28,7 @@ def parse_json_markdown(json_string: str) -> dict: extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) else: - raise Exception("Could not find JSON block in the output.") + raise ValueError("could not find json block in the output.") return parsed @@ -36,10 +37,10 @@ def parse_and_check_json_markdown(text: str, 
expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserError(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"got invalid json object. error: {e}") for key in expected_keys: if key not in json_obj: raise OutputParserError( - f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index ab0ac3beb2..174640d986 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,8 +1,9 @@ from functools import wraps +from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in -from flask_login.config import EXEMPT_METHODS +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -13,7 +14,7 @@ from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user = LocalProxy(lambda: _get_user()) +current_user: Any = LocalProxy(lambda: _get_user()) def login_required(func): @@ -80,12 +81,12 @@ def login_required(func): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: - return current_app.login_manager.unauthorized() + return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility # current_app.ensure_sync is only available in Flask >= 2.0 @@ -99,7 +100,7 @@ def login_required(func): def _get_user() -> EndUser | Account | None: if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() + current_app.login_manager._load_user() # type: ignore return g._login_user diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 6b6919de24..df75b55019 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -77,9 +77,9 @@ class GitHubOAuth(OAuth): email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email["primary"] == True), None) + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) - return {**user_info, "email": primary_email["email"]} + return {**user_info, "email": primary_email.get("email", "")} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: email = raw_info.get("email") @@ -130,4 +130,4 @@ class GoogleOAuth(OAuth): return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1d39abd8fa..0c872a0066 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,8 +1,9 @@ import datetime import urllib.parse 
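The GitHubOAuth change above replaces a next() lookup that could return None with one defaulting to an empty dict, so an account with no primary address yields "" instead of a TypeError on subscript. A minimal sketch of the pattern, with illustrative data shaped like GitHub's /user/emails response:

emails = [
    {"email": "dev@example.com", "primary": False, "verified": True},
    {"email": "me@example.com", "primary": True, "verified": True},
]
primary: dict = next((e for e in emails if e["primary"]), {})
print(primary.get("email", ""))  # "me@example.com"; "" when nothing is primary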
+from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from extensions.ext_database import db from models.source import DataSourceOauthBinding @@ -226,7 +227,7 @@ class NotionOAuth(OAuthDataSource): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "page", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } @@ -281,7 +282,7 @@ class NotionOAuth(OAuthDataSource): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "database", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py index d356def418..e4d63fd314 100644 --- a/api/libs/threadings_utils.py +++ b/api/libs/threadings_utils.py @@ -9,8 +9,8 @@ def apply_gevent_threading_patch(): :return: """ if not dify_config.DEBUG: - from gevent import monkey - from grpc.experimental import gevent as grpc_gevent + from gevent import monkey # type: ignore + from grpc.experimental import gevent as grpc_gevent # type: ignore # gevent monkey.patch_all() diff --git a/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py new file mode 100644 index 0000000000..881a9e3c1e --- /dev/null +++ b/api/migrations/versions/2024_12_19_1746-11b07f66c737_remove_unused_tool_providers.py @@ -0,0 +1,39 @@ +"""remove unused tool_providers + +Revision ID: 11b07f66c737 +Revises: cf8f4fc45278 +Create Date: 2024-12-19 17:46:25.780116 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '11b07f66c737' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tool_providers') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('tool_providers', + sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), + sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False), + sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False), + sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True), + sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'), + sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + ) + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py new file mode 100644 index 0000000000..814dec423c --- /dev/null +++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py @@ -0,0 +1,37 @@ +"""add retry_index field to node-execution model +Revision ID: e1944c35e15e +Revises: 11b07f66c737 +Create Date: 2024-12-20 06:28:30.287197 +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e1944c35e15e' +down_revision = '11b07f66c737' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # We don't need these fields anymore, but this file is already merged into the main branch, + # so we need to keep this file for the sake of history, and this change will be reverted in the next migration. + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + pass + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + # with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + # batch_op.drop_column('retry_index') + pass + + # ### end Alembic commands ### \ No newline at end of file diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py new file mode 100644 index 0000000000..ea129d15f7 --- /dev/null +++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py @@ -0,0 +1,34 @@ +"""remove workflow_node_executions.retry_index if exists + +Revision ID: d7999dfa4aae +Revises: e1944c35e15e +Create Date: 2024-12-23 11:54:15.344543 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy import inspect + + +# revision identifiers, used by Alembic. 
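The migration below (d7999dfa4aae) cannot assume retry_index exists: databases that ran the original e1944c35e15e carry the column, while fresh installs running the now-commented version never create it, so the drop must consult the live schema. A sketch of the same idempotent pattern applied to a whole table; the table name is illustrative, not part of the patch:

from alembic import op
from sqlalchemy import inspect


def upgrade():
    conn = op.get_bind()
    # drop only if an earlier, since-reverted migration actually created it
    if "legacy_scratch_table" in inspect(conn).get_table_names():
        op.drop_table("legacy_scratch_table")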
+revision = 'd7999dfa4aae' +down_revision = 'e1944c35e15e' +branch_labels = None +depends_on = None + + +def upgrade(): + # Check if column exists before attempting to remove it + conn = op.get_bind() + inspector = inspect(conn) + has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')] + + if has_column: + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_column('retry_index') + + +def downgrade(): + # No downgrade needed as we don't want to restore the column + pass diff --git a/api/models/__init__.py b/api/models/__init__.py index 61a38870cf..b0b9880ca4 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,53 +1,187 @@ -from .account import Account, AccountIntegrate, InvitationCode, Tenant -from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .account import ( + Account, + AccountIntegrate, + AccountStatus, + InvitationCode, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) +from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .dataset import ( + AppDatasetJoin, + Dataset, + DatasetCollectionBinding, + DatasetKeywordTable, + DatasetPermission, + DatasetPermissionEnum, + DatasetProcessRule, + DatasetQuery, + Document, + DocumentSegment, + Embedding, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, + TidbAuthBinding, + Whitelist, +) +from .engine import db +from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( + ApiRequest, ApiToken, App, + AppAnnotationHitHistory, + AppAnnotationSetting, AppMode, + AppModelConfig, Conversation, + DatasetRetrieverResource, + DifySetup, EndUser, + IconType, InstalledApp, Message, + MessageAgentThought, MessageAnnotation, + MessageChain, + MessageFeedback, MessageFile, + OperationLog, RecommendedApp, Site, + Tag, + TagBinding, + TraceAppConfig, UploadFile, ) -from .source import DataSourceOauthBinding -from .tools import ToolFile +from .provider import ( + LoadBalancingModelConfig, + Provider, + ProviderModel, + ProviderModelSetting, + ProviderOrder, + ProviderQuotaType, + ProviderType, + TenantDefaultModel, + TenantPreferredModelProvider, +) +from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from .task import CeleryTask, CeleryTaskSet +from .tools import ( + ApiToolProvider, + BuiltinToolProvider, + PublishedAppTool, + ToolConversationVariables, + ToolFile, + ToolLabelBinding, + ToolModelInvoke, + WorkflowToolProvider, +) +from .web import PinnedConversation, SavedMessage from .workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowAppLogCreatedFrom, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, WorkflowRun, + WorkflowRunStatus, + WorkflowType, ) __all__ = [ + "APIBasedExtension", + "APIBasedExtensionPoint", "Account", "AccountIntegrate", + "AccountStatus", + "ApiRequest", "ApiToken", + "ApiToolProvider", # Added "App", + "AppAnnotationHitHistory", + "AppAnnotationSetting", + "AppDatasetJoin", "AppMode", + "AppModelConfig", + "BuiltinToolProvider", # Added + "CeleryTask", + "CeleryTaskSet", "Conversation", "ConversationVariable", + "CreatedByRole", + "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", + "DatasetCollectionBinding", + "DatasetKeywordTable", + "DatasetPermission", + "DatasetPermissionEnum", "DatasetProcessRule", + "DatasetQuery", + "DatasetRetrieverResource", + "DifySetup", "Document", 
"DocumentSegment", + "Embedding", "EndUser", + "ExternalKnowledgeApis", + "ExternalKnowledgeBindings", + "IconType", "InstalledApp", "InvitationCode", + "LoadBalancingModelConfig", "Message", + "MessageAgentThought", "MessageAnnotation", + "MessageChain", + "MessageFeedback", "MessageFile", + "OperationLog", + "PinnedConversation", + "Provider", + "ProviderModel", + "ProviderModelSetting", + "ProviderOrder", + "ProviderQuotaType", + "ProviderType", + "PublishedAppTool", "RecommendedApp", + "SavedMessage", "Site", + "Tag", + "TagBinding", "Tenant", + "TenantAccountJoin", + "TenantAccountJoinRole", + "TenantAccountRole", + "TenantDefaultModel", + "TenantPreferredModelProvider", + "TenantStatus", + "TidbAuthBinding", + "ToolConversationVariables", "ToolFile", + "ToolLabelBinding", + "ToolModelInvoke", + "TraceAppConfig", "UploadFile", + "UserFrom", + "Whitelist", "Workflow", "WorkflowAppLog", + "WorkflowAppLogCreatedFrom", + "WorkflowNodeExecution", + "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionTriggeredFrom", "WorkflowRun", + "WorkflowRunStatus", + "WorkflowRunTriggeredFrom", + "WorkflowToolProvider", + "WorkflowType", + "db", ] diff --git a/api/models/account.py b/api/models/account.py index f040ac39f5..16e229192d 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -4,9 +4,10 @@ import json from flask_login import UserMixin from sqlalchemy.orm import Mapped, mapped_column -from extensions.ext_database import db from models.base import Base +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -33,11 +34,11 @@ class Account(UserMixin, Base): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def is_password_set(self): @@ -45,7 +46,8 @@ class Account(UserMixin, Base): @property def current_tenant(self): - return self._current_tenant + # FIXME: fix the type error later, because the type is important maybe cause some bugs + return self._current_tenant # type: ignore @current_tenant.setter def current_tenant(self, value: "Tenant"): @@ -78,7 +80,7 @@ class Account(UserMixin, Base): tenant.current_role = ta.role else: tenant = None - except: + except Exception: tenant = None self._current_tenant = tenant @@ -92,7 +94,7 @@ class Account(UserMixin, Base): return AccountStatus(status_str) @classmethod - def get_by_openid(cls, provider: str, open_id: str) -> db.Model: + def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) @@ -189,7 +191,7 @@ class TenantAccountRole(enum.StrEnum): } -class Tenant(db.Model): +class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = 
(db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -199,8 +201,8 @@ class Tenant(db.Model): plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -225,7 +227,7 @@ class TenantAccountJoinRole(enum.Enum): DATASET_OPERATOR = "dataset_operator" -class TenantAccountJoin(db.Model): +class TenantAccountJoin(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_account_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -240,11 +242,11 @@ class TenantAccountJoin(db.Model): current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): +class AccountIntegrate(db.Model): # type: ignore[name-defined] __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -257,11 +259,11 @@ class AccountIntegrate(db.Model): provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): +class InvitationCode(db.Model): # type: ignore[name-defined] __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 97173747af..6b6d808710 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,7 +1,8 @@ import enum -from extensions.ext_database import db +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -12,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): +class APIBasedExtension(db.Model): # type: ignore[name-defined] __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), @@ -24,4 +25,4 @@ class APIBasedExtension(db.Model): name = 
db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 8ab957e875..b9b41dcf47 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,16 +9,17 @@ import pickle import re import time from json import JSONDecodeError +from typing import Any, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from extensions.ext_storage import storage from .account import Account +from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID @@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): +class Dataset(db.Model): # type: ignore[name-defined] __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -50,9 +51,9 @@ class Dataset(db.Model): indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -200,7 +201,7 @@ class Dataset(db.Model): return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): +class DatasetProcessRule(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -212,11 +213,11 @@ class DatasetProcessRule(db.Model): mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES = { + AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -242,7 +243,7 @@ class DatasetProcessRule(db.Model): return None -class Document(db.Model): +class Document(db.Model): # type: ignore[name-defined] __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -264,7 +265,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) 
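The change repeated across these models swaps the hand-written Postgres default db.text("CURRENT_TIMESTAMP(0)") for SQLAlchemy's func.current_timestamp(), which renders a dialect-appropriate DEFAULT clause. One small behavioral difference worth noting: the old default rounded to whole seconds, while plain CURRENT_TIMESTAMP keeps full fractional precision. A minimal sketch with an illustrative table:

from sqlalchemy import Column, DateTime, Integer, func
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Example(Base):
    __tablename__ = "examples"  # illustrative, not part of the patch
    id = Column(Integer, primary_key=True)
    # rendered as DEFAULT CURRENT_TIMESTAMP by the active dialect
    created_at = Column(DateTime, nullable=False, server_default=func.current_timestamp())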
created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -303,7 +304,7 @@ class Document(db.Model): archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) @@ -492,7 +493,7 @@ class Document(db.Model): ) -class DocumentSegment(db.Model): +class DocumentSegment(db.Model): # type: ignore[name-defined] __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -527,9 +528,9 @@ class DocumentSegment(db.Model): disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -604,7 +605,7 @@ class DocumentSegment(db.Model): return text -class AppDatasetJoin(db.Model): +class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -621,7 +622,7 @@ class AppDatasetJoin(db.Model): return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): +class DatasetQuery(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -638,7 +639,7 @@ class DatasetQuery(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class DatasetKeywordTable(db.Model): +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -683,7 +684,7 @@ class DatasetKeywordTable(db.Model): return None -class Embedding(db.Model): +class Embedding(db.Model): # type: ignore[name-defined] __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -697,17 +698,17 @@ class Embedding(db.Model): ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return cast(list[float], pickle.loads(self.embedding)) -class DatasetCollectionBinding(db.Model): +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -719,10 +720,10 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): +class TidbAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -739,10 +740,10 @@ class TidbAuthBinding(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): +class Whitelist(db.Model): # type: ignore[name-defined] __tablename__ = "whitelists" __table_args__ = ( db.PrimaryKeyConstraint("id", name="whitelists_pkey"), @@ -751,10 +752,10 @@ class Whitelist(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetPermission(db.Model): +class DatasetPermission(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_permissions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), @@ -768,10 +769,10 @@ class DatasetPermission(db.Model): account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ExternalKnowledgeApis(db.Model): +class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_apis" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), @@ -785,9 +786,9 @@ class ExternalKnowledgeApis(db.Model): tenant_id = db.Column(StringUUID, nullable=False) 
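pickle.loads() is annotated as returning Any, so the Embedding.get_embedding change above narrows the result with typing.cast rather than a runtime check; cast() is erased at runtime and exists only to satisfy mypy. A minimal sketch:

import pickle
from typing import cast


def get_embedding(blob: bytes) -> list[float]:
    # cast() asserts the type to the checker; it performs no validation
    return cast(list[float], pickle.loads(blob))

Contrast this with the VolcengineTosStorage change earlier in the patch, which opts for a runtime isinstance check instead, presumably because that SDK can genuinely return a non-bytes body.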
settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -824,7 +825,7 @@ class ExternalKnowledgeApis(db.Model): return dataset_bindings -class ExternalKnowledgeBindings(db.Model): +class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), @@ -840,6 +841,6 @@ class ExternalKnowledgeBindings(db.Model): dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/engine.py b/api/models/engine.py new file mode 100644 index 0000000000..dda93bc941 --- /dev/null +++ b/api/models/engine.py @@ -0,0 +1,13 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import MetaData + +POSTGRES_INDEXES_NAMING_CONVENTION = { + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", +} + +metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) +db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 2423513085..8a707a59e5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,5 +1,4 @@ import json -import logging import re import uuid from collections.abc import Mapping @@ -22,20 +21,24 @@ from flask import request from flask_login import UserMixin from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy.orm import Mapped, Session, mapped_column +from typing import TYPE_CHECKING, cast + from configs import dify_config from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from extensions.ext_database import db from libs.helper import generate_string from models.base import Base from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus from .account import Account, Tenant +from .engine import db from .types import StringUUID -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from .workflow import Workflow class DifySetup(Base): @@ -43,7 +46,7 @@ class DifySetup(Base): __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -96,11 +99,11 @@ class App(Base): is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) - max_active_requests = db.Column(db.Integer, nullable=True) + max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -165,7 +168,7 @@ class App(Base): if self.mode == AppMode.CHAT.value and self.is_agent: return AppMode.AGENT_CHAT.value - return self.mode + return str(self.mode) @property def deleted_tools(self) -> list: @@ -207,7 +210,6 @@ class App(Base): provider_id = GenericProviderID(provider_id, is_hardcoded) except Exception: - logger.exception(f"Invalid builtin provider id: {provider_id}") continue builtin_provider_ids.append(provider_id) @@ -307,9 +309,9 @@ class AppModelConfig(Base): model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -425,7 +427,7 @@ class AppModelConfig(Base): @property def dataset_configs_dict(self) -> dict: if self.dataset_configs: - dataset_configs = json.loads(self.dataset_configs) + dataset_configs: dict = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: @@ -566,8 +568,8 @@ class RecommendedApp(Base): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -591,7 +593,7 @@ class InstalledApp(Base): position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, 
server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -632,8 +634,8 @@ class Conversation(Base): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( @@ -645,13 +647,29 @@ class Conversation(Base): @property def inputs(self): inputs = self._inputs.copy() + + # Convert file mapping to File object for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) + return inputs @inputs.setter @@ -667,6 +685,8 @@ class Conversation(Base): @property def model_config(self): model_config = {} + app_model_config: Optional[AppModelConfig] = None + if self.mode == AppMode.ADVANCED_CHAT.value: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) @@ -678,6 +698,7 @@ class Conversation(Base): if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) + assert app_model_config is not None, "app model config not found" model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -764,6 +785,31 @@ class Conversation(Base): return {"like": like, "dislike": dislike} + @property + def status_count(self): + messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() + status_counts = { + WorkflowRunStatus.RUNNING: 0, + WorkflowRunStatus.SUCCEEDED: 0, + WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, + WorkflowRunStatus.PARTIAL_SUCCESSED: 0, + } + + for message in messages: + if message.workflow_run: + status_counts[message.workflow_run.status] += 1 
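Conversation.inputs and Message.inputs above now rebuild File objects through file_factory.build_from_mapping instead of File.model_validate, and the serialized mapping stores its id under related_id while the factory keys off a transfer-method-specific field. A sketch of that normalization as a standalone helper; the function name is hypothetical, and the string values assume the serialized form of FileTransferMethod:

def normalize_file_mapping(value: dict) -> dict:
    # build_from_mapping expects tool_file_id / upload_file_id, not related_id
    if value["transfer_method"] == "tool_file":
        value["tool_file_id"] = value["related_id"]
    elif value["transfer_method"] == "local_file":
        value["upload_file_id"] = value["related_id"]
    return value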
+ + return ( + { + "success": status_counts[WorkflowRunStatus.SUCCEEDED], + "failed": status_counts[WorkflowRunStatus.FAILED], + "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCESSED], + } + if messages + else None + ) + @property def first_message(self): return db.session.query(Message).filter(Message.conversation_id == self.id).first() @@ -834,8 +880,8 @@ class Message(Base): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -843,12 +889,25 @@ class Message(Base): def inputs(self): inputs = self._inputs.copy() for key, value in inputs.items(): + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - inputs[key] = File.model_validate(value) + if value["transfer_method"] == FileTransferMethod.TOOL_FILE: + value["tool_file_id"] = value["related_id"] + elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE: + value["upload_file_id"] = value["related_id"] + inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) elif isinstance(value, list) and all( isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value ): - inputs[key] = [File.model_validate(item) for item in value] + inputs[key] = [] + for item in value: + if item["transfer_method"] == FileTransferMethod.TOOL_FILE: + item["tool_file_id"] = item["related_id"] + elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE: + item["upload_file_id"] = item["related_id"] + inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) return inputs @inputs.setter @@ -1022,7 +1081,7 @@ class Message(Base): if not current_app: raise ValueError(f"App {self.app_id} not found") - files: list[File] = [] + files = [] for message_file in message_files: if message_file.transfer_method == "local_file": if message_file.upload_file_id is None: @@ -1147,8 +1206,8 @@ class MessageFeedback(Base): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1194,9 +1253,7 @@ class MessageFile(Base): upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - 
created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(Base): @@ -1216,8 +1273,8 @@ class MessageAnnotation(Base): content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1246,7 +1303,7 @@ class AppAnnotationHitHistory(Base): source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) @@ -1280,9 +1337,9 @@ class AppAnnotationSetting(Base): score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_account(self): @@ -1328,9 +1385,9 @@ class OperationLog(Base): account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): @@ -1349,8 +1406,8 @@ class EndUser(Base, UserMixin): name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Site(Base): @@ -1381,9 +1438,9 
@@ class Site(Base): prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) code = db.Column(db.String(255)) @property @@ -1425,7 +1482,7 @@ class ApiToken(Base): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1456,9 +1513,7 @@ class UploadFile(Base): db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) @@ -1515,7 +1570,7 @@ class ApiRequest(Base): request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): @@ -1573,7 +1628,7 @@ class MessageAgentThought(Base): @property def files(self) -> list: if self.message_files: - return json.loads(self.message_files) + return cast(list[Any], json.loads(self.message_files)) else: return [] @@ -1585,20 +1640,20 @@ class MessageAgentThought(Base): def tool_labels(self) -> dict: try: if self.tool_labels_str: - return json.loads(self.tool_labels_str) + return cast(dict, json.loads(self.tool_labels_str)) else: return {} - except Exception as e: + except Exception: return {} @property def tool_meta(self) -> dict: try: if self.tool_meta_str: - return json.loads(self.tool_meta_str) + return cast(dict, json.loads(self.tool_meta_str)) else: return {} - except Exception as e: + except Exception: return {} @property @@ -1619,7 +1674,7 @@ class MessageAgentThought(Base): return result else: return {tool: {} for tool in tools} - except Exception as e: + except Exception: return {} @property @@ -1640,9 +1695,11 @@ class MessageAgentThought(Base): return result else: return {tool: {} for tool in tools} - except Exception as e: + except Exception: if self.observation: return dict.fromkeys(tools, self.observation) + else: + return {} class DatasetRetrieverResource(Base): @@ -1687,7 +1744,7 @@ class Tag(Base): type = db.Column(db.String(16), 
nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): @@ -1703,7 +1760,7 @@ class TagBinding(Base): tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): @@ -1717,8 +1774,10 @@ class TraceAppConfig(Base): app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property diff --git a/api/models/provider.py b/api/models/provider.py index 58c1978573..b7889be8b5 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,8 +1,8 @@ from enum import Enum - -from extensions.ext_database import db from models.base import Base +from sqlalchemy import func +from .engine import db from .types import StringUUID @@ -18,6 +18,24 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") +class ProviderQuotaType(Enum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value): + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + class Provider(Base): """ Provider model representing the API providers and their configurations. 
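Aside: the ProviderQuotaType enum added above mirrors the value_of lookup that ProviderType already provides in this file. A minimal usage sketch — illustrative only, assuming the enum is imported from models.provider:

    from models.provider import ProviderQuotaType

    # Round-trip a stored string back to the enum member.
    quota = ProviderQuotaType.value_of("trial")
    assert quota is ProviderQuotaType.TRIAL

    # Unknown values raise ValueError rather than returning None.
    try:
        ProviderQuotaType.value_of("enterprise")
    except ValueError as exc:
        print(exc)  # No matching enum found for value 'enterprise'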
@@ -44,8 +62,8 @@ class Provider(Base): quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -92,8 +110,8 @@ class ProviderModel(Base): model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -108,8 +126,8 @@ class TenantDefaultModel(Base): provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -123,8 +141,8 @@ class TenantPreferredModelProvider(Base): tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -148,8 +166,8 @@ class ProviderOrder(Base): paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -170,8 +188,8 @@ class ProviderModelSetting(Base): model_type = db.Column(db.String(40), nullable=False) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, 
nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -193,5 +211,5 @@ class LoadBalancingModelConfig(Base): name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index efd94227d0..7f976b5ed4 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,14 +1,16 @@ import json +from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB -from extensions.ext_database import db from models.base import Base + +from .engine import db from .types import StringUUID -class DataSourceOauthBinding(Base): +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -21,8 +23,8 @@ class DataSourceOauthBinding(Base): access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) @@ -39,8 +41,8 @@ class DataSourceApiKeyAuthBinding(Base): category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): diff --git a/api/models/task.py b/api/models/task.py index b6783873df..b42af31855 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,9 +1,9 @@ from datetime import UTC, datetime -from celery import states +from celery import states # type: ignore -from extensions.ext_database import db from models.base import Base +from .engine import db class CeleryTask(Base): diff --git a/api/models/tool.py b/api/models/tool.py deleted file mode 100644 index d70c905851..0000000000 --- a/api/models/tool.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -from enum import Enum - -from extensions.ext_database import db 
-from models.base import Base - -from .types import StringUUID - - -class ToolProviderName(Enum): - SERPAPI = "serpapi" - - @staticmethod - def value_of(value): - for member in ToolProviderName: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class ToolProvider(Base): - __tablename__ = "tool_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), - db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), - ) - - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - tool_name = db.Column(db.String(40), nullable=False) - encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - - @property - def credentials_is_set(self): - """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ - return self.encrypted_credentials is not None - - @property - def credentials(self): - """ - Returns the decrypted config. - """ - return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None diff --git a/api/models/tools.py b/api/models/tools.py index 248e28e0b9..0fcd87d2b9 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -4,15 +4,15 @@ from typing import Optional import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from extensions.ext_database import db from models.base import Base +from .engine import db from .model import Account, App, Tenant from .types import StringUUID @@ -85,8 +85,8 @@ class ApiToolProvider(Base): # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -98,7 +98,7 @@ class ApiToolProvider(Base): @property def credentials(self) -> dict: - return json.loads(self.credentials_str) + return dict(json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -224,8 +224,8 @@ class ToolModelInvoke(Base): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -252,12 +252,12 @@ class ToolConversationVariables(Base): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> dict: - return json.loads(self.variables_str) + return dict(json.loads(self.variables_str)) class ToolFile(Base): diff --git a/api/models/web.py b/api/models/web.py index 934008a443..2b8caf0492 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,6 +1,8 @@ -from extensions.ext_database import db from models.base import Base +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column +from .engine import db from .model import Message from .types import StringUUID @@ -17,7 +19,7 @@ class SavedMessage(Base): message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -33,7 +35,7 @@ class PinnedConversation(Base): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 832206e5bc..6e2bdf2392 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: from models.model import AppMode from enum import StrEnum +from typing import TYPE_CHECKING import sqlalchemy as sa from sqlalchemy import Index, PrimaryKeyConstraint, func @@ -16,15 +17,18 @@ import contexts from constants import HIDDEN_VALUE from core.helper import encrypter from core.variables import SecretVariable, Variable -from extensions.ext_database import db from factories import variable_factory from libs import helper from models.base import Base from models.enums import CreatedByRole from .account import Account +from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from models.model import AppMode + class WorkflowType(Enum): """ @@ -108,12 +112,13 @@ class Workflow(Base): graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column( - db.DateTime, 
nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() + db.DateTime, + nullable=False, + default=datetime.now(UTC).replace(tzinfo=None), + server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -186,7 +191,7 @@ class Workflow(Base): self._features = value @property - def features_dict(self) -> Mapping[str, Any]: + def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -203,7 +208,7 @@ class Workflow(Base): return [] # get user_input_form from start node - variables = start_node.get("data", {}).get("variables", []) + variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] @@ -250,11 +255,13 @@ class Workflow(Base): ] # decrypt secret variables value - decrypt_func = ( - lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - if isinstance(var, SecretVariable) - else var - ) + def decrypt_func(var): + return ( + var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) + if isinstance(var, SecretVariable) + else var + ) + results = list(map(decrypt_func, results)) return results @@ -278,11 +285,13 @@ class Workflow(Base): value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - encrypt_func = ( - lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) - if isinstance(var, SecretVariable) - else var - ) + def encrypt_func(var): + return ( + var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) + if isinstance(var, SecretVariable) + else var + ) + encrypted_vars = list(map(encrypt_func, value)) environment_variables_json = json.dumps( {var.name: var.model_dump() for var in encrypted_vars}, @@ -404,14 +413,14 @@ class WorkflowRun(Base): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded - outputs: Mapped[str] = mapped_column(sa.Text, default="{}") + outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @@ -534,6 +543,7 @@ class WorkflowNodeExecutionStatus(Enum): SUCCEEDED = "succeeded" FAILED = "failed" EXCEPTION = 
"exception" + RETRY = "retry" @classmethod def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": @@ -640,7 +650,7 @@ class WorkflowNodeExecution(Base): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -758,7 +768,7 @@ class WorkflowAppLog(Base): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -789,7 +799,7 @@ class ConversationVariable(Base): conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) data = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) diff --git a/api/mypy.ini b/api/mypy.ini new file mode 100644 index 0000000000..2c754f9fcd --- /dev/null +++ b/api/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +check_untyped_defs = True +exclude = (?x)( + core/tools/provider/builtin/ + | core/model_runtime/model_providers/ + | tests/ + | migrations/ + ) \ No newline at end of file diff --git a/api/poetry.lock b/api/poetry.lock index 2cdd07202c..b42eb22dd4 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
[[package]] name = "aiofiles" @@ -955,6 +955,10 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -967,8 +971,14 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -979,8 
+989,24 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, + {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, + {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, + {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, + {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -990,6 +1016,10 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -1001,6 +1031,10 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1013,6 +1047,10 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1025,6 +1063,10 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -5601,6 +5643,58 @@ files = [ {file = "multitasking-0.0.11.tar.gz", hash = "sha256:4d6bc3cc65f9b2dca72fb5a787850a88dae8f620c2b36ae9b55248e51bcd6026"}, ] +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = 
"mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -6495,6 +6589,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pathos" version = "0.3.3" @@ -7482,23 +7591,24 @@ image = ["Pillow (>=8.0.0)"] [[package]] name = "pypdfium2" -version = "4.17.0" +version = "4.30.0" description = "Python bindings to PDFium" optional = false python-versions = ">=3.6" files = [ - {file = "pypdfium2-4.17.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:e9ed42d5a5065ae41ae3ead3cd642e1f21b6039e69ccc204e260e218e91cd7e1"}, - {file = "pypdfium2-4.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:0a3b5a8eca53a1e68434969821b70bd2bc9ac2b70e58daf516c6ff0b6b5779e7"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:854e04b51205466ec415b86588fe5dc593e9ca3e8e15b5aa05978c5352bd57d2"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_armv7l.whl", hash = "sha256:9ff8707b28568e9585bdf9a96b7a8a9f91c0b5ad05af119b49381dad89983364"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_i686.whl", hash = "sha256:09ecbef6212993db0b5460cfd46d6b157a921ff45c97b0764e6fe8ea2e8cdebf"}, - {file = "pypdfium2-4.17.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:f680e469b79c71c3fb086d7ced8361fbd66f4cd7b0ad08ff888289fe6743ab32"}, - {file = "pypdfium2-4.17.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ba7a7da48fbf0f1aaa903dac7d0e62186d6e8ae9a78b7b7b836d3f1b3d1be5d"}, - {file = "pypdfium2-4.17.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:451752170caf59d4b4572b527c2858dfff96eb1da35f2822c66cdce006dd4eae"}, - {file = "pypdfium2-4.17.0-py3-none-win32.whl", hash = "sha256:4930cfa793298214fa644c6986f6466e21f98eba3f338b4577614ebd8aa34af5"}, - {file = "pypdfium2-4.17.0-py3-none-win_amd64.whl", hash = "sha256:99de7f336e967dea4d324484f581fff55db1eb3c8e90baa845567dd9a3cc84f3"}, - {file = "pypdfium2-4.17.0-py3-none-win_arm64.whl", hash = "sha256:9381677b489c13d64ea4f8cbf6ebfc858216b052883e01e40fa993c2818a078e"}, - {file = "pypdfium2-4.17.0.tar.gz", hash = "sha256:2a2b3273c4614ee2004df60ace5f387645f843418ae29f379408ee11560241c0"}, + {file = "pypdfium2-4.30.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:b33ceded0b6ff5b2b93bc1fe0ad4b71aa6b7e7bd5875f1ca0cdfb6ba6ac01aab"}, + {file = "pypdfium2-4.30.0-py3-none-macosx_11_0_arm64.whl", hash = 
"sha256:4e55689f4b06e2d2406203e771f78789bd4f190731b5d57383d05cf611d829de"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e6e50f5ce7f65a40a33d7c9edc39f23140c57e37144c2d6d9e9262a2a854854"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3d0dd3ecaffd0b6dbda3da663220e705cb563918249bda26058c6036752ba3a2"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc3bf29b0db8c76cdfaac1ec1cde8edf211a7de7390fbf8934ad2aa9b4d6dfad"}, + {file = "pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1f78d2189e0ddf9ac2b7a9b9bd4f0c66f54d1389ff6c17e9fd9dc034d06eb3f"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:5eda3641a2da7a7a0b2f4dbd71d706401a656fea521b6b6faa0675b15d31a163"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_i686.whl", hash = "sha256:0dfa61421b5eb68e1188b0b2231e7ba35735aef2d867d86e48ee6cab6975195e"}, + {file = "pypdfium2-4.30.0-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:f33bd79e7a09d5f7acca3b0b69ff6c8a488869a7fab48fdf400fec6e20b9c8be"}, + {file = "pypdfium2-4.30.0-py3-none-win32.whl", hash = "sha256:ee2410f15d576d976c2ab2558c93d392a25fb9f6635e8dd0a8a3a5241b275e0e"}, + {file = "pypdfium2-4.30.0-py3-none-win_amd64.whl", hash = "sha256:90dbb2ac07be53219f56be09961eb95cf2473f834d01a42d901d13ccfad64b4c"}, + {file = "pypdfium2-4.30.0-py3-none-win_arm64.whl", hash = "sha256:119b2969a6d6b1e8d55e99caaf05290294f2d0fe49c12a3f17102d01c441bd29"}, + {file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"}, ] [[package]] @@ -9212,13 +9322,13 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlparse" -version = "0.5.2" +version = "0.5.3" description = "A non-validating SQL parser." optional = false python-versions = ">=3.8" files = [ - {file = "sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e"}, - {file = "sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f"}, + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, ] [package.extras] @@ -9804,6 +9914,17 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" +[[package]] +name = "types-pytz" +version = "2024.2.0.20241003" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"}, + {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"}, +] + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -10270,82 +10391,82 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic [[package]] name = "watchfiles" -version = "1.0.0" +version = "1.0.3" description = "Simple, modern and high performance file watching and code reload in python." 
optional = false python-versions = ">=3.9" files = [ - {file = "watchfiles-1.0.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1d19df28f99d6a81730658fbeb3ade8565ff687f95acb59665f11502b441be5f"}, - {file = "watchfiles-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28babb38cf2da8e170b706c4b84aa7e4528a6fa4f3ee55d7a0866456a1662041"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ab123135b2f42517f04e720526d41448667ae8249e651385afb5cda31fedc0"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13a4f9ee0cd25682679eea5c14fc629e2eaa79aab74d963bc4e21f43b8ea1877"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e1d9284cc84de7855fcf83472e51d32daf6f6cecd094160192628bc3fee1b78"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ee5edc939f53466b329bbf2e58333a5461e6c7b50c980fa6117439e2c18b42d"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dccfc70480087567720e4e36ec381bba1ed68d7e5f368fe40c93b3b1eba0105"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c83a6d33a9eda0af6a7470240d1af487807adc269704fe76a4972dd982d16236"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:905f69aad276639eff3893759a07d44ea99560e67a1cf46ff389cd62f88872a2"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09551237645d6bff3972592f2aa5424df9290e7a2e15d63c5f47c48cde585935"}, - {file = "watchfiles-1.0.0-cp310-none-win32.whl", hash = "sha256:d2b39aa8edd9e5f56f99a2a2740a251dc58515398e9ed5a4b3e5ff2827060755"}, - {file = "watchfiles-1.0.0-cp310-none-win_amd64.whl", hash = "sha256:2de52b499e1ab037f1a87cb8ebcb04a819bf087b1015a4cf6dcf8af3c2a2613e"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fbd0ab7a9943bbddb87cbc2bf2f09317e74c77dc55b1f5657f81d04666c25269"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:774ef36b16b7198669ce655d4f75b4c3d370e7f1cbdfb997fb10ee98717e2058"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b4fb98100267e6a5ebaff6aaa5d20aea20240584647470be39fe4823012ac96"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0fc3bf0effa2d8075b70badfdd7fb839d7aa9cea650d17886982840d71fdeabf"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:648e2b6db53eca6ef31245805cd528a16f56fa4cc15aeec97795eaf713c11435"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa13d604fcb9417ae5f2e3de676e66aa97427d888e83662ad205bed35a313176"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:936f362e7ff28311b16f0b97ec51e8f2cc451763a3264640c6ed40fb252d1ee4"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245fab124b9faf58430da547512d91734858df13f2ddd48ecfa5e493455ffccb"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4ff9c7e84e8b644a8f985c42bcc81457240316f900fc72769aaedec9d088055a"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:9c9a8d8fd97defe935ef8dd53d562e68942ad65067cd1c54d6ed8a088b1d931d"}, - {file = "watchfiles-1.0.0-cp311-none-win32.whl", hash = "sha256:a0abf173975eb9dd17bb14c191ee79999e650997cc644562f91df06060610e62"}, - {file = "watchfiles-1.0.0-cp311-none-win_amd64.whl", hash = "sha256:2a825ba4b32c214e3855b536eb1a1f7b006511d8e64b8215aac06eb680642d84"}, - {file = "watchfiles-1.0.0-cp311-none-win_arm64.whl", hash = "sha256:a5a7a06cfc65e34fd0a765a7623c5ba14707a0870703888e51d3d67107589817"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:28fb64b5843d94e2c2483f7b024a1280662a44409bedee8f2f51439767e2d107"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3750434c83b61abb3163b49c64b04180b85b4dabb29a294513faec57f2ffdb7"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bedf84835069f51c7b026b3ca04e2e747ea8ed0a77c72006172c72d28c9f69fc"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90004553be36427c3d06ec75b804233f8f816374165d5225b93abd94ba6e7234"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b46e15c34d4e401e976d6949ad3a74d244600d5c4b88c827a3fdf18691a46359"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:487d15927f1b0bd24e7df921913399bb1ab94424c386bea8b267754d698f8f0e"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ff236d7a3f4b0a42f699a22fc374ba526bc55048a70cbb299661158e1bb5e1f"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c01446626574561756067f00b37e6b09c8622b0fc1e9fdbc7cbcea328d4e514"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b551c465a59596f3d08170bd7e1c532c7260dd90ed8135778038e13c5d48aa81"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1ed613ee107269f66c2df631ec0fc8efddacface85314d392a4131abe299f00"}, - {file = "watchfiles-1.0.0-cp312-none-win32.whl", hash = "sha256:5f75cd42e7e2254117cf37ff0e68c5b3f36c14543756b2da621408349bd9ca7c"}, - {file = "watchfiles-1.0.0-cp312-none-win_amd64.whl", hash = "sha256:cf517701a4a872417f4e02a136e929537743461f9ec6cdb8184d9a04f4843545"}, - {file = "watchfiles-1.0.0-cp312-none-win_arm64.whl", hash = "sha256:8a2127cd68950787ee36753e6d401c8ea368f73beaeb8e54df5516a06d1ecd82"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:95de85c254f7fe8cbdf104731f7f87f7f73ae229493bebca3722583160e6b152"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:533a7cbfe700e09780bb31c06189e39c65f06c7f447326fee707fd02f9a6e945"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2218e78e2c6c07b1634a550095ac2a429026b2d5cbcd49a594f893f2bb8c936"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9122b8fdadc5b341315d255ab51d04893f417df4e6c1743b0aac8bf34e96e025"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9272fdbc0e9870dac3b505bce1466d386b4d8d6d2bacf405e603108d50446940"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3b33c3aefe9067ebd87846806cd5fc0b017ab70d628aaff077ab9abf4d06b3"}, - {file = 
"watchfiles-1.0.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc338ce9f8846543d428260fa0f9a716626963148edc937d71055d01d81e1525"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ac778a460ea22d63c7e6fb0bc0f5b16780ff0b128f7f06e57aaec63bd339285"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53ae447f06f8f29f5ab40140f19abdab822387a7c426a369eb42184b021e97eb"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1f73c2147a453315d672c1ad907abe6d40324e34a185b51e15624bc793f93cc6"}, - {file = "watchfiles-1.0.0-cp313-none-win32.whl", hash = "sha256:eba98901a2eab909dbd79681190b9049acc650f6111fde1845484a4450761e98"}, - {file = "watchfiles-1.0.0-cp313-none-win_amd64.whl", hash = "sha256:d562a6114ddafb09c33246c6ace7effa71ca4b6a2324a47f4b09b6445ea78941"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3d94fd83ed54266d789f287472269c0def9120a2022674990bd24ad989ebd7a0"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48051d1c504448b2fcda71c5e6e3610ae45de6a0b8f5a43b961f250be4bdf5a8"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29cf884ad4285d23453c702ed03d689f9c0e865e3c85d20846d800d4787de00f"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d3572d4c34c4e9c33d25b3da47d9570d5122f8433b9ac6519dca49c2740d23cd"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c2696611182c85eb0e755b62b456f48debff484b7306b56f05478b843ca8ece"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:550109001920a993a4383b57229c717fa73627d2a4e8fcb7ed33c7f1cddb0c85"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b555a93c15bd2c71081922be746291d776d47521a00703163e5fbe6d2a402399"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947ccba18a38b85c366dafeac8df2f6176342d5992ca240a9d62588b214d731f"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ffd98a299b0a74d1b704ef0ed959efb753e656a4e0425c14e46ae4c3cbdd2919"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f8c4f3a1210ed099a99e6a710df4ff2f8069411059ffe30fa5f9467ebed1256b"}, - {file = "watchfiles-1.0.0-cp39-none-win32.whl", hash = "sha256:1e176b6b4119b3f369b2b4e003d53a226295ee862c0962e3afd5a1c15680b4e3"}, - {file = "watchfiles-1.0.0-cp39-none-win_amd64.whl", hash = "sha256:2d9c0518fabf4a3f373b0a94bb9e4ea7a1df18dec45e26a4d182aa8918dee855"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f159ac795785cde4899e0afa539f4c723fb5dd336ce5605bc909d34edd00b79b"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c3d258d78341d5d54c0c804a5b7faa66cd30ba50b2756a7161db07ce15363b8d"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbd0311588c2de7f9ea5cf3922ccacfd0ec0c1922870a2be503cc7df1ca8be7"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a13ac46b545a7d0d50f7641eefe47d1597e7d1783a5d89e09d080e6dff44b0"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:b2bca898c1dc073912d3db7fa6926cc08be9575add9e84872de2c99c688bac4e"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d828fe2adc4ac8a64b875ca908b892a3603d596d43e18f7948f3fef5fc671c"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074c7618cd6c807dc4eaa0982b4a9d3f8051cd0b72793511848fd64630174b17"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95dc785bc284552d044e561b8f4fe26d01ab5ca40d35852a6572d542adfeb4bc"}, - {file = "watchfiles-1.0.0.tar.gz", hash = "sha256:37566c844c9ce3b5deb964fe1a23378e575e74b114618d211fbda8f59d7b5dab"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1da46bb1eefb5a37a8fb6fd52ad5d14822d67c498d99bda8754222396164ae42"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2b961b86cd3973f5822826017cad7f5a75795168cb645c3a6b30c349094e02e3"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34e87c7b3464d02af87f1059fedda5484e43b153ef519e4085fe1a03dd94801e"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9dd2b89a16cf7ab9c1170b5863e68de6bf83db51544875b25a5f05a7269e678"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b4691234d31686dca133c920f94e478b548a8e7c750f28dbbc2e4333e0d3da9"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:90b0fe1fcea9bd6e3084b44875e179b4adcc4057a3b81402658d0eb58c98edf8"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b90651b4cf9e158d01faa0833b073e2e37719264bcee3eac49fc3c74e7d304b"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2e9fe695ff151b42ab06501820f40d01310fbd58ba24da8923ace79cf6d702d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62691f1c0894b001c7cde1195c03b7801aaa794a837bd6eef24da87d1542838d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:275c1b0e942d335fccb6014d79267d1b9fa45b5ac0639c297f1e856f2f532552"}, + {file = "watchfiles-1.0.3-cp310-cp310-win32.whl", hash = "sha256:06ce08549e49ba69ccc36fc5659a3d0ff4e3a07d542b895b8a9013fcab46c2dc"}, + {file = "watchfiles-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f280b02827adc9d87f764972fbeb701cf5611f80b619c20568e1982a277d6146"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ffe709b1d0bc2e9921257569675674cafb3a5f8af689ab9f3f2b3f88775b960f"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:418c5ce332f74939ff60691e5293e27c206c8164ce2b8ce0d9abf013003fb7fe"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f492d2907263d6d0d52f897a68647195bc093dafed14508a8d6817973586b6b"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48c9f3bc90c556a854f4cab6a79c16974099ccfa3e3e150673d82d47a4bc92c9"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75d3bcfa90454dba8df12adc86b13b6d85fda97d90e708efc036c2760cc6ba44"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:5691340f259b8f76b45fb31b98e594d46c36d1dc8285efa7975f7f50230c9093"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e263cc718545b7f897baeac1f00299ab6fabe3e18caaacacb0edf6d5f35513c"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6cf7709ed3e55704cc06f6e835bf43c03bc8e3cb8ff946bf69a2e0a78d9d77"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:703aa5e50e465be901e0e0f9d5739add15e696d8c26c53bc6fc00eb65d7b9469"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bfcae6aecd9e0cb425f5145afee871465b98b75862e038d42fe91fd753ddd780"}, + {file = "watchfiles-1.0.3-cp311-cp311-win32.whl", hash = "sha256:6a76494d2c5311584f22416c5a87c1e2cb954ff9b5f0988027bc4ef2a8a67181"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:cf745cbfad6389c0e331786e5fe9ae3f06e9d9c2ce2432378e1267954793975c"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:2dcc3f60c445f8ce14156854a072ceb36b83807ed803d37fdea2a50e898635d6"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:93436ed550e429da007fbafb723e0769f25bae178fbb287a94cb4ccdf42d3af3"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c18f3502ad0737813c7dad70e3e1cc966cc147fbaeef47a09463bbffe70b0a00"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a5bc3ca468bb58a2ef50441f953e1f77b9a61bd1b8c347c8223403dc9b4ac9a"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0d1ec043f02ca04bf21b1b32cab155ce90c651aaf5540db8eb8ad7f7e645cba8"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f58d3bfafecf3d81c15d99fc0ecf4319e80ac712c77cf0ce2661c8cf8bf84066"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1df924ba82ae9e77340101c28d56cbaff2c991bd6fe8444a545d24075abb0a87"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:632a52dcaee44792d0965c17bdfe5dc0edad5b86d6a29e53d6ad4bf92dc0ff49"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bf4b459d94a0387617a1b499f314aa04d8a64b7a0747d15d425b8c8b151da0"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca94c85911601b097d53caeeec30201736ad69a93f30d15672b967558df02885"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65ab1fb635476f6170b07e8e21db0424de94877e4b76b7feabfe11f9a5fc12b5"}, + {file = "watchfiles-1.0.3-cp312-cp312-win32.whl", hash = "sha256:49bc1bc26abf4f32e132652f4b3bfeec77d8f8f62f57652703ef127e85a3e38d"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:48681c86f2cb08348631fed788a116c89c787fdf1e6381c5febafd782f6c3b44"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:9e080cf917b35b20c889225a13f290f2716748362f6071b859b60b8847a6aa43"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e153a690b7255c5ced17895394b4f109d5dcc2a4f35cb809374da50f0e5c456a"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac1be85fe43b4bf9a251978ce5c3bb30e1ada9784290441f5423a28633a958a7"}, + {file = 
"watchfiles-1.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec98e31e1844eac860e70d9247db9d75440fc8f5f679c37d01914568d18721"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0179252846be03fa97d4d5f8233d1c620ef004855f0717712ae1c558f1974a16"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:995c374e86fa82126c03c5b4630c4e312327ecfe27761accb25b5e1d7ab50ec8"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29b9cb35b7f290db1c31fb2fdf8fc6d3730cfa4bca4b49761083307f441cac5a"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f8dc09ae69af50bead60783180f656ad96bd33ffbf6e7a6fce900f6d53b08f1"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:489b80812f52a8d8c7b0d10f0d956db0efed25df2821c7a934f6143f76938bd6"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:228e2247de583475d4cebf6b9af5dc9918abb99d1ef5ee737155bb39fb33f3c0"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1550be1a5cb3be08a3fb84636eaafa9b7119b70c71b0bed48726fd1d5aa9b868"}, + {file = "watchfiles-1.0.3-cp313-cp313-win32.whl", hash = "sha256:16db2d7e12f94818cbf16d4c8938e4d8aaecee23826344addfaaa671a1527b07"}, + {file = "watchfiles-1.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:160eff7d1267d7b025e983ca8460e8cc67b328284967cbe29c05f3c3163711a3"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c05b021f7b5aa333124f2a64d56e4cb9963b6efdf44e8d819152237bbd93ba15"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:310505ad305e30cb6c5f55945858cdbe0eb297fc57378f29bacceb534ac34199"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddff3f8b9fa24a60527c137c852d0d9a7da2a02cf2151650029fdc97c852c974"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46e86ed457c3486080a72bc837300dd200e18d08183f12b6ca63475ab64ed651"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f79fe7993e230a12172ce7d7c7db061f046f672f2b946431c81aff8f60b2758b"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea2b51c5f38bad812da2ec0cd7eec09d25f521a8b6b6843cbccedd9a1d8a5c15"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fe4e740ea94978b2b2ab308cbf9270a246bcbb44401f77cc8740348cbaeac3d"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9af037d3df7188ae21dc1c7624501f2f90d81be6550904e07869d8d0e6766655"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52bb50a4c4ca2a689fdba84ba8ecc6a4e6210f03b6af93181bb61c4ec3abaf86"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c14a07bdb475eb696f85c715dbd0f037918ccbb5248290448488a0b4ef201aad"}, + {file = "watchfiles-1.0.3-cp39-cp39-win32.whl", hash = "sha256:be37f9b1f8934cd9e7eccfcb5612af9fb728fecbe16248b082b709a9d1b348bf"}, + {file = "watchfiles-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ef9ec8068cf23458dbf36a08e0c16f0a2df04b42a8827619646637be1769300a"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:84fac88278f42d61c519a6c75fb5296fd56710b05bbdcc74bdf85db409a03780"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c68be72b1666d93b266714f2d4092d78dc53bd11cf91ed5a3c16527587a52e29"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889a37e2acf43c377b5124166bece139b4c731b61492ab22e64d371cce0e6e80"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca05cacf2e5c4a97d02a2878a24020daca21dbb8823b023b978210a75c79098"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8af4b582d5fc1b8465d1d2483e5e7b880cc1a4e99f6ff65c23d64d070867ac58"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:127de3883bdb29dbd3b21f63126bb8fa6e773b74eaef46521025a9ce390e1073"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713f67132346bdcb4c12df185c30cf04bdf4bf6ea3acbc3ace0912cab6b7cb8c"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abd85de513eb83f5ec153a802348e7a5baa4588b818043848247e3e8986094e8"}, + {file = "watchfiles-1.0.3.tar.gz", hash = "sha256:f3ff7da165c99a5412fe5dd2304dd2dbaaaa5da718aad942dcb3a178eaa70c56"}, ] [package.dependencies] @@ -11052,4 +11173,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "1aa6a44bc9270d50c9c0ea09f55a304b5148bf4dbbbb068ff1b1ea8da6fa60cc" +content-hash = "f4accd01805cbf080c4c5295f97a06c8e4faec7365d2c43d0435e56b46461732" diff --git a/api/pyproject.toml b/api/pyproject.toml index a20c129e9c..28e0305406 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -60,13 +60,14 @@ oci = "~2.135.1" openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } +pandas-stubs = "~2.2.3.241009" psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.9.2" pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" -pypdfium2 = "~4.17.0" +pypdfium2 = "~4.30.0" python = ">=3.11,<3.13" python-docx = "~1.1.0" python-dotenv = "1.0.0" @@ -84,6 +85,7 @@ tencentcloud-sdk-python-hunyuan = "~3.0.1158" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" +types-pytz = "~2024.2.0.20241003" unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} @@ -173,6 +175,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" faker = "~32.1.0" +mypy = "~1.13.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 97e5c77e95..48bdc872f4 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -32,8 +32,9 @@ def clean_messages(): while True: try: # Main query with join and filter + # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) + db.session.query(Message) # type: ignore .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index e12be649e4..f66b3c4797 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ 
-52,8 +52,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -120,8 +119,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index a20b500308..1c985461c6 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -36,14 +36,15 @@ def create_tidb_serverless_task(): def create_clusters(batch_size): try: + # TODO: maybe we can set the default value for the following parameters in the config file new_clusters = TidbService.batch_create_tidb_serverless_cluster( - batch_size, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + batch_size=batch_size, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + region=dify_config.TIDB_REGION or "", ) for new_cluster in new_clusters: tidb_auth_binding = TidbAuthBinding( diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index b2d8746f9c..11a39e60ee 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): # batch 20 for i in range(0, len(tidb_serverless_list), 20): items = tidb_serverless_list[i : i + 20] + # TODO: maybe we can set the default value for the following parameters in the config file TidbService.batch_update_tidb_serverless_cluster_status( - items, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, + tidb_serverless_list=items, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", ) except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/account_service.py b/api/services/account_service.py index 7613f48a3e..13b70db580 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,7 +6,7 @@ import secrets import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional +from typing import Any, Optional, cast from pydantic import BaseModel from sqlalchemy import func @@ -120,7 +120,7 @@ class AccountService: 
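A note on the two TidbService call sites above: the TIDB_* settings on dify_config are typed Optional[str], while the service methods take plain str arguments, so each value is narrowed with or "" and passed by keyword to keep the call unambiguous. A minimal sketch of why mypy requires the narrowing; the function and variable names below are illustrative, not the real service API:

from typing import Optional

def batch_create(*, project_id: str, region: str) -> None:
    # Stand-in for a service method that requires non-optional strings.
    print(project_id, region)

tidb_project_id: Optional[str] = None  # e.g. an unset config field
tidb_region: Optional[str] = "ap-southeast-1"

# Passing tidb_project_id directly fails mypy: Optional[str] is not str.
# The expression `x or ""` narrows the type, mapping None to "".
batch_create(project_id=tidb_project_id or "", region=tidb_region or "")

The trade-off is that an unset setting silently becomes an empty string, so the callee still has to treat "" as missing.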
account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return account + return cast(Account, account) @staticmethod def get_account_jwt_token(account: Account) -> str: @@ -133,7 +133,7 @@ class AccountService: "sub": "Console API Passport", } - token = PassportService().issue(payload) + token: str = PassportService().issue(payload) return token @staticmethod @@ -165,7 +165,7 @@ class AccountService: db.session.commit() - return account + return cast(Account, account) @staticmethod def update_account_password(account, password, new_password): @@ -348,6 +348,8 @@ class AccountService: language: Optional[str] = "en-US", ): account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError @@ -378,6 +380,8 @@ class AccountService: def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" ): + if email is None: + raise ValueError("Email must be provided.") if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError @@ -421,7 +425,7 @@ class AccountService: if count is None: count = 0 count = int(count) + 1 - redis_client.setex(key, 60 * 60 * 24, count) + redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count) @staticmethod def is_login_error_rate_limit(email: str) -> bool: @@ -670,7 +674,7 @@ class TenantService: @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return db.session.query(func.count(Tenant.id)).scalar() + return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: @@ -734,10 +738,10 @@ class TenantService: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> None: - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404() + def get_custom_config(tenant_id: str) -> dict: + tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() - return tenant.custom_config_dict + return cast(dict, tenant.custom_config_dict) class RegisterService: @@ -808,7 +812,7 @@ class RegisterService: account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None or provider is not None: + if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) if FeatureService.get_system_features().is_allow_create_workspace: @@ -896,7 +900,9 @@ class RegisterService: redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid( + cls, workspace_id: Optional[str], email: str, token: str + ) -> Optional[dict[str, Any]]: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -955,7 +961,7 @@ class RegisterService: if not data: return None - invitation = json.loads(data) + invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 
d2cd7bea67..6dc1affa11 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -48,6 +48,8 @@ class AdvancedPromptTemplateService: return cls.get_chat_prompt( copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt ) + # default return empty dict + return {} @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: @@ -91,3 +93,5 @@ class AdvancedPromptTemplateService: return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) + # default return empty dict + return {} diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 762760b168..bb5a1892a4 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,7 +1,8 @@ import threading +from typing import Optional import pytz -from flask_login import current_user +from flask_login import current_user # type: ignore import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager @@ -33,7 +34,7 @@ class AgentService: if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = ( + message: Optional[Message] = ( db.session.query(Message) .filter( Message.id == message_id, diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f45c21cb18..a946405c95 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,8 +1,9 @@ import datetime import uuid +from typing import cast import pandas as pd -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import or_ from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -71,7 +72,7 @@ class AppAnnotationService: app_id, annotation_setting.collection_binding_id, ) - return annotation + return cast(MessageAnnotation, annotation) @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: @@ -124,8 +125,7 @@ class AppAnnotationService: raise NotFound("App not found") if keyword: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -137,8 +137,7 @@ class AppAnnotationService: ) else: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) @@ -327,8 +326,7 @@ class AppAnnotationService: raise NotFound("Annotation not found") annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .filter( + AppAnnotationHitHistory.query.filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 63496d3df3..7793fdc4ff 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,7 +1,7 @@ import logging import uuid from enum import StrEnum -from typing import Optional +from typing import Optional, cast from uuid import uuid4 import yaml @@ -33,8 +33,8 @@ logger = logging.getLogger(__name__) 
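A note on the service hunks above: results of .first()-style ORM lookups are Optional at the type level, so the annotations gain Optional[...], and typing.cast is used where the code asserts a narrower type than mypy can infer. cast does nothing at runtime; it only informs the checker. A minimal sketch with illustrative names:

from typing import Optional, cast

class Message:
    answer: str = ""

def first_or_none(rows: list[object]) -> Optional[Message]:
    # A .first()-style lookup really can produce None.
    if not rows:
        return None
    # cast() only informs the type checker; nothing is validated at
    # runtime, so a wrong assertion surfaces later as an AttributeError.
    return cast(Message, rows[0])

message = first_or_none([Message()])
if message is not None:  # the None check narrows Optional[Message] to Message
    print(message.answer)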
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes -CURRENT_DSL_VERSION = "0.1.4" DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB +CURRENT_DSL_VERSION = "0.1.5" class ImportMode(StrEnum): @@ -124,7 +124,7 @@ class AppDslService: raise ValueError(f"Invalid import_mode: {import_mode}") # Get YAML content - content = "" + content: bytes | str = b"" if mode == ImportMode.YAML_URL: if not yaml_url: return Import( @@ -156,7 +156,7 @@ ) try: - content = content.decode("utf-8") + content = cast(bytes, content).decode("utf-8") except UnicodeDecodeError as e: return Import( id=import_id, @@ -391,7 +391,10 @@ ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) - app_mode = AppMode(app_data["mode"]) + app_mode = app_data.get("mode") + if not app_mode: + raise ValueError("missing app mode") + app_mode = AppMode(app_mode) # Set icon type icon_type_value = icon_type or app_data.get("icon_type") @@ -410,6 +413,9 @@ app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + # Create new app app = App() app.id = str(uuid4()) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index a1ada9670d..39b4afa252 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -114,7 +114,7 @@ class AppGenerateService: @staticmethod def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests - if app_model.max_active_requests is None: + if max_active_requests is None: max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests diff --git a/api/services/app_service.py b/api/services/app_service.py index 8d8ba735ec..41c15bbf0a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,9 +1,9 @@ import json import logging from datetime import UTC, datetime -from typing import cast +from typing import Optional, cast -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -83,7 +83,7 @@ class AppService: # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -100,6 +100,8 @@ class AppService: else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") default_model_dict = { "provider": model_instance.provider, @@ -109,7 +111,7 @@ } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model @@ -314,7 +316,7 @@ """ app_mode = 
AppMode.value_of(app_model.mode) - meta = {"tool_icons": {}} + meta: dict = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -336,7 +338,7 @@ class AppService: } ) else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config: Optional[AppModelConfig] = app_model.app_model_config if not app_model_config: return meta @@ -352,16 +354,18 @@ class AppService: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type") - provider_id = tool.get("provider_id") - tool_name = tool.get("tool_name") + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a0cd5725b..973110f515 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -110,6 +110,8 @@ class AudioService: voices = model_instance.get_tts_voices() if voices: voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") @@ -121,6 +123,8 @@ class AudioService: if message_id: message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None if message.answer == "" and message.status == "normal": return None @@ -130,6 +134,8 @@ class AudioService: return Response(stream_with_context(response), content_type="audio/mpeg") return response else: + if not text: + raise ValueError("Text is required") response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): return Response(stream_with_context(response), content_type="audio/mpeg") diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index afc491398f..50e4edff14 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -11,8 +11,8 @@ class FirecrawlAuth(ApiKeyAuthBase): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) - self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index de898a1f94..6100e9afc8 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = 
credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index de898a1f94..6100e9afc8 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -11,7 +11,7 @@ class JinaAuth(ApiKeyAuthBase): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 911d234641..d980186488 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,6 +1,8 @@ import os +from typing import Optional -import requests +import httpx +from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -39,11 +41,17 @@ class BillingService: return cls._send_request("GET", "/invoices", params=params) @classmethod + @retry( + wait=wait_fixed(2), + stop=stop_before_delay(10), + retry=retry_if_not_exception_type(httpx.RequestError), + reraise=True, + ) def _send_request(cls, method, endpoint, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers) + response = httpx.request(method, url, json=json, params=params, headers=headers) return response.json() @@ -51,11 +59,14 @@ class BillingService: def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = ( + join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) + if not join: + raise ValueError("Tenant account join not found") + if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 8642972710..6485cbf37d 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import UTC, datetime from typing import Optional, Union -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,19 +19,21 @@ class ConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Conversation).filter( + stmt = 
select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -38,37 +41,39 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - if include_ids is not None: - base_query = base_query.filter(Conversation.id.in_(include_ids)) - + stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: - base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) if last_id: - last_conversation = base_query.filter(Conversation.id == last_id).first() + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() # build filters based on sorting - filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) - base_query = base_query.filter(filter_condition) - - base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) - - conversations = base_query.limit(limit).all() + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] rest_filter_condition = cls._build_filter_condition( - sort_field, sort_direction, current_page_last_conversation, is_next_page=True + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, ) - rest_count = base_query.filter(rest_filter_condition).count() - + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) + rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True @@ -81,11 +86,9 @@ class ConversationService: return sort_by, asc @classmethod - def _build_filter_condition( - cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False - ): + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + if sort_direction == desc: return getattr(Conversation, sort_field) < field_value else: return getattr(Conversation, sort_field) > field_value diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 1686f9fb87..ca741f1935 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import time import uuid from typing import Any, Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -186,8 +186,9 @@ class DatasetService: return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset: - return Dataset.query.filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> 
Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset @staticmethod def check_dataset_model_setting(dataset): @@ -228,14 +229,20 @@ class DatasetService: @staticmethod def update_dataset(dataset_id, data, user): dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": - dataset.retrieval_model = data.get("external_retrieval_model", None) + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", "") + permission = data.get("permission") + if permission: + dataset.permission = permission external_knowledge_id = data.get("external_knowledge_id", None) - dataset.permission = data.get("permission") db.session.add(dataset) if not external_knowledge_id: raise ValueError("External knowledge id is required.") @@ -367,7 +374,13 @@ class DatasetService: raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): + def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") + + if not user: + raise ValueError("User not found") + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: raise NoPermissionError("You do not have permission to access this dataset.") @@ -761,6 +774,11 @@ class DocumentService: rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) + else: + logging.warning( + f"Invalid process rule mode: {process_rule['mode']}, cannot find dataset process rule" + ) + return db.session.add(dataset_process_rule) db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) @@ -1005,9 +1023,10 @@ class DocumentService: rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.commit() - document.dataset_process_rule_id = dataset_process_rule.id + if dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id # update document data source if document_data.get("data_source"): file_name = "" @@ -1553,7 +1572,7 @@ class SegmentService: segment.word_count = len(content) if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change if segment_update_entity.keywords: segment.keywords = segment_update_entity.keywords @@ -1568,7 +1587,8 @@ class SegmentService: db.session.add(document) # update segment index task if segment_update_entity.enabled: - VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) + keywords = segment_update_entity.keywords or [] + VectorService.create_segments_vector([keywords], [segment], dataset) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1600,7 +1620,7 @@ class SegmentService: segment.disabled_by = None if document.doc_form == "qa_model": segment.answer = 
segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1618,8 +1638,8 @@ class SegmentService: segment.status = "error" segment.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() - return segment + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1679,6 +1699,8 @@ class DatasetCollectionBindingService: .order_by(DatasetCollectionBinding.created_at) .first() ) + if not dataset_collection_binding: + raise ValueError("Dataset collection binding not found") return dataset_collection_binding diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 92098f06cc..3c3f970444 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -8,8 +8,8 @@ class EnterpriseRequest: secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") proxies = { - "http": None, - "https": None, + "http": "", + "https": "", } @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 7d0d442776..91875dd2f2 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -6,6 +6,9 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration +from core.entities.model_entities import ( + SimpleModelProviderEntity, +) from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ( @@ -158,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. 
""" - provider: SimpleProviderEntityResponse + provider: SimpleModelProviderEntity def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: dump_model = model.model_dump() diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 7be20301a7..898624066b 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx import validators @@ -45,7 +45,10 @@ class ExternalDatasetService: @staticmethod def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: - ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) external_knowledge_api = ExternalKnowledgeApis( tenant_id=tenant_id, created_by=user_id, @@ -86,11 +89,16 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( id=external_knowledge_api_id, tenant_id=tenant_id ).first() if external_knowledge_api is None: @@ -127,7 +135,7 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() if not external_knowledge_binding: @@ -163,8 +171,9 @@ class ExternalDatasetService: "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) - + response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( + data=json.dumps(settings.params), files=files, **kwargs + ) return response @staticmethod @@ -265,15 +274,15 @@ class ExternalDatasetService: "knowledge_id": external_knowledge_binding.external_knowledge_id, } - external_knowledge_api_setting = { - "url": f"{settings.get('endpoint')}/retrieval", - "request_method": "post", - "headers": headers, - "params": request_params, - } response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, ) if response.status_code == 200: - return response.json().get("records", []) + return cast(list[Any], response.json().get("records", [])) return [] diff --git 
a/api/services/feature_service.py b/api/services/feature_service.py index fbdb3e8b83..36c79d7045 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -65,6 +65,7 @@ class SystemFeatureModel(BaseModel): enable_social_oauth_login: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False + is_email_setup: bool = False license: LicenseModel = LicenseModel() @@ -103,6 +104,7 @@ class FeatureService: system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): diff --git a/api/services/file_service.py b/api/services/file_service.py index b12b95ca13..d417e81734 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,7 @@ import hashlib import uuid from typing import Any, Literal, Union -from flask_login import current_user +from flask_login import current_user # type: ignore from werkzeug.exceptions import NotFound from configs import dify_config @@ -61,14 +61,14 @@ class FileService: # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension + file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, content) # save file to db upload_file = UploadFile( - tenant_id=current_tenant_id, + tenant_id=current_tenant_id or "", storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7957b4dc82..41b4e1ec46 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,5 +1,6 @@ import logging import time +from typing import Any from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document @@ -24,7 +25,7 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - retrieval_model: dict, + retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, ) -> dict: @@ -68,7 +69,7 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_retrieve_response(dataset, query, all_documents)) @classmethod def external_retrieve( @@ -102,13 +103,16 @@ class HitTestingService: db.session.add(dataset_query) db.session.commit() - return cls.compact_external_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): records = [] for document in documents: + if document.metadata is None: + continue + index_node_id = document.metadata["doc_id"] segment = ( @@ -140,7 +144,7 @@ class HitTestingService: } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: records = [] if dataset.provider == "external": for document in documents: @@ -152,11 +156,10 @@ class 
diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py
index 02fe1d19bc..8df1a6ba14 100644
--- a/api/services/knowledge_service.py
+++ b/api/services/knowledge_service.py
@@ -1,4 +1,4 @@
-import boto3
+import boto3  # type: ignore
 
 from configs import dify_config
diff --git a/api/services/message_service.py b/api/services/message_service.py
index f432a77c80..c4447a84da 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -151,8 +151,13 @@ class MessageService:
 
     @classmethod
     def create_feedback(
-        cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
-    ) -> MessageFeedback:
+        cls,
+        app_model: App,
+        message_id: str,
+        user: Optional[Union[Account, EndUser]],
+        rating: Optional[str],
+        content: Optional[str],
+    ):
         if not user:
             raise ValueError("user cannot be None")
 
@@ -164,6 +169,7 @@ class MessageService:
             db.session.delete(feedback)
         elif rating and feedback:
             feedback.rating = rating
+            feedback.content = content
         elif not rating and not feedback:
             raise ValueError("rating cannot be None when feedback not exists")
         else:
@@ -172,6 +178,7 @@ class MessageService:
                 conversation_id=message.conversation_id,
                 message_id=message.id,
                 rating=rating,
+                content=content,
                 from_source=("user" if isinstance(user, EndUser) else "admin"),
                 from_end_user_id=(user.id if isinstance(user, EndUser) else None),
                 from_account_id=(user.id if isinstance(user, Account) else None),
@@ -257,6 +264,8 @@ class MessageService:
             )
 
             app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
+            if not app_model_config:
+                raise ValueError("did not find app model config")
 
         suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict
         if suggested_questions_after_answer.get("enabled", False) is False:
@@ -278,7 +287,7 @@ class MessageService:
         )
 
         with measure_time() as timer:
-            questions = LLMGenerator.generate_suggested_questions_after_answer(
+            questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
                 tenant_id=app_model.tenant_id, histories=histories
             )
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index 1d9a736540..26311a6377 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -2,7 +2,7 @@ import datetime
 import json
 import logging
 from json import JSONDecodeError
-from typing import Optional
+from typing import Optional, Union
 
 from constants import HIDDEN_VALUE
 from core.entities.provider_configuration import ProviderConfiguration
@@ -88,11 +88,11 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Convert model type to ModelType
-        model_type = ModelType.value_of(model_type)
+        model_type_enum = ModelType.value_of(model_type)
 
         # Get provider model setting
         provider_model_setting = provider_configuration.get_provider_model_setting(
-            model_type=model_type,
+            model_type=model_type_enum,
             model=model,
         )
 
@@ -106,7 +106,7 @@ class ModelLoadBalancingService:
             .filter(
                 LoadBalancingModelConfig.tenant_id == tenant_id,
                 LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                 LoadBalancingModelConfig.model_name == model,
             )
             .order_by(LoadBalancingModelConfig.created_at)
@@ -124,7 +124,7 @@ class ModelLoadBalancingService:
 
         if not inherit_config_exists:
             # Initialize the inherit configuration
-            inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
+            inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum)
 
             # prepend the inherit configuration
             load_balancing_configs.insert(0, inherit_config)
@@ -148,7 +148,7 @@ class ModelLoadBalancingService:
                 tenant_id=tenant_id,
                 provider=provider,
                 model=model,
-                model_type=model_type,
+                model_type=model_type_enum,
                 config_id=load_balancing_config.id,
             )
 
@@ -214,7 +214,7 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Convert model type to ModelType
-        model_type = ModelType.value_of(model_type)
+        model_type_enum = ModelType.value_of(model_type)
 
         # Get load balancing configurations
         load_balancing_model_config = (
@@ -222,7 +222,7 @@ class ModelLoadBalancingService:
             .filter(
                 LoadBalancingModelConfig.tenant_id == tenant_id,
                 LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                 LoadBalancingModelConfig.model_name == model,
                 LoadBalancingModelConfig.id == config_id,
             )
@@ -300,7 +300,7 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Convert model type to ModelType
-        model_type = ModelType.value_of(model_type)
+        model_type_enum = ModelType.value_of(model_type)
 
         if not isinstance(configs, list):
             raise ValueError("Invalid load balancing configs")
@@ -310,7 +310,7 @@ class ModelLoadBalancingService:
             .filter(
                 LoadBalancingModelConfig.tenant_id == tenant_id,
                 LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
-                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                 LoadBalancingModelConfig.model_name == model,
             )
             .all()
@@ -359,7 +359,7 @@ class ModelLoadBalancingService:
                     credentials = self._custom_credentials_validate(
                         tenant_id=tenant_id,
                         provider_configuration=provider_configuration,
-                        model_type=model_type,
+                        model_type=model_type_enum,
                         model=model,
                         credentials=credentials,
                         load_balancing_model_config=load_balancing_config,
@@ -395,7 +395,7 @@ class ModelLoadBalancingService:
                 credentials = self._custom_credentials_validate(
                     tenant_id=tenant_id,
                     provider_configuration=provider_configuration,
-                    model_type=model_type,
+                    model_type=model_type_enum,
                     model=model,
                     credentials=credentials,
                     validate=False,
@@ -405,7 +405,7 @@ class ModelLoadBalancingService:
                 load_balancing_model_config = LoadBalancingModelConfig(
                     tenant_id=tenant_id,
                     provider_name=provider_configuration.provider.provider,
-                    model_type=model_type.to_origin_model_type(),
+                    model_type=model_type_enum.to_origin_model_type(),
                     model_name=model,
                     name=name,
                     encrypted_config=json.dumps(credentials),
@@ -450,7 +450,7 @@ class ModelLoadBalancingService:
             raise ValueError(f"Provider {provider} does not exist.")
 
         # Convert model type to ModelType
-        model_type = ModelType.value_of(model_type)
+        model_type_enum = ModelType.value_of(model_type)
 
         load_balancing_model_config = None
         if config_id:
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
                 .filter(
                     LoadBalancingModelConfig.tenant_id == tenant_id,
                     LoadBalancingModelConfig.provider_name == provider,
-                    LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+                    LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
                     LoadBalancingModelConfig.model_name == model,
                     LoadBalancingModelConfig.id == config_id,
                 )
@@ -474,7 +474,7 @@ class ModelLoadBalancingService:
             self._custom_credentials_validate(
                 tenant_id=tenant_id,
                 provider_configuration=provider_configuration,
-                model_type=model_type,
+                model_type=model_type_enum,
                 model=model,
                 credentials=credentials,
                 load_balancing_model_config=load_balancing_model_config,
@@ -548,19 +548,14 @@ class ModelLoadBalancingService:
     def _get_credential_schema(
         self, provider_configuration: ProviderConfiguration
-    ) -> ModelCredentialSchema | ProviderCredentialSchema:
-        """
-        Get form schemas.
-        :param provider_configuration: provider configuration
-        :return:
-        """
-        # Get credential form schemas from model credential schema or provider credential schema
+    ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]:
+        """Get form schemas."""
         if provider_configuration.provider.model_credential_schema:
-            credential_schema = provider_configuration.provider.model_credential_schema
+            return provider_configuration.provider.model_credential_schema
+        elif provider_configuration.provider.provider_credential_schema:
+            return provider_configuration.provider.provider_credential_schema
         else:
-            credential_schema = provider_configuration.provider.provider_credential_schema
-
-        return credential_schema
+            raise ValueError("No credential schema found")
 
     def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
         """
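The model_type to model_type_enum renames in model_load_balancing_service.py follow a standard mypy workaround: a parameter annotated as str cannot be rebound to an enum value without a type error, so the converted value gets its own name. A reduced sketch of the pattern, with ModelType stubbed in as a stand-in for the real enum:

from enum import Enum


class ModelType(Enum):  # stand-in for core.model_runtime.entities.model_entities.ModelType
    LLM = "llm"

    @classmethod
    def value_of(cls, value: str) -> "ModelType":
        return cls(value)


def get_setting(model_type: str) -> None:
    # Rebinding model_type would change its type from str to ModelType mid-function,
    # which mypy rejects; a second name keeps each binding's type stable.
    model_type_enum = ModelType.value_of(model_type)
    print(model_type_enum.name)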
diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py
index 589af9d87e..0a0a5619e1 100644
--- a/api/services/model_provider_service.py
+++ b/api/services/model_provider_service.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Optional
 
-from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
+from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
 from core.model_runtime.entities.model_entities import ModelType, ParameterRule
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
@@ -98,20 +98,12 @@ class ModelProviderService:
     def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]:
         """
         get provider credentials.
-
-        :param tenant_id:
-        :param provider:
-        :return:
         """
-        # Get all provider configurations of the current workspace
         provider_configurations = self.provider_manager.get_configurations(tenant_id)
-
-        # Get provider configuration
         provider_configuration = provider_configurations.get(provider)
         if not provider_configuration:
             raise ValueError(f"Provider {provider} does not exist.")
 
-        # Get provider custom credentials from workspace
         return provider_configuration.get_custom_credentials(obfuscated=True)
 
     def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
@@ -282,7 +274,7 @@ class ModelProviderService:
         models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
 
         # Group models by provider
-        provider_models = {}
+        provider_models: dict[str, list[ModelWithProviderEntity]] = {}
         for model in models:
             if model.provider.provider not in provider_models:
                 provider_models[model.provider.provider] = []
diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py
index dfb21e767f..082afeed89 100644
--- a/api/services/moderation_service.py
+++ b/api/services/moderation_service.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from core.moderation.factory import ModerationFactory, ModerationOutputsResult
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
@@ -5,7 +7,7 @@ from models.model import App, AppModelConfig
 
 class ModerationService:
     def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
-        app_model_config: AppModelConfig = None
+        app_model_config: Optional[AppModelConfig] = None
 
         app_model_config = (
             db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 1160a1f275..fc1e08518b 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
 from extensions.ext_database import db
 from models.model import App, TraceAppConfig
@@ -12,7 +14,7 @@ class OpsService:
         :param tracing_provider: tracing provider
         :return:
         """
-        trace_config_data: TraceAppConfig = (
+        trace_config_data: Optional[TraceAppConfig] = (
             db.session.query(TraceAppConfig)
             .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
@@ -22,7 +24,10 @@ class OpsService:
             return None
 
         # decrypt_token and obfuscated_token
-        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+        tenant = db.session.query(App).filter(App.id == app_id).first()
+        if not tenant:
+            return None
+        tenant_id = tenant.tenant_id
         decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
             tenant_id, tracing_provider, trace_config_data.tracing_config
         )
@@ -73,8 +78,9 @@ class OpsService:
             provider_config_map[tracing_provider]["config_class"],
             provider_config_map[tracing_provider]["other_keys"],
         )
-        default_config_instance = config_class(**tracing_config)
-        for key in other_keys:
+        # FIXME: ignore type error
+        default_config_instance = config_class(**tracing_config)  # type: ignore
+        for key in other_keys:  # type: ignore
             if key in tracing_config and tracing_config[key] == "":
                 tracing_config[key] = getattr(default_config_instance, key, None)
 
@@ -92,7 +98,7 @@ class OpsService:
         project_url = None
 
         # check if trace config already exists
-        trace_config_data: TraceAppConfig = (
+        trace_config_data: Optional[TraceAppConfig] = (
             db.session.query(TraceAppConfig)
             .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
             .first()
@@ -102,7 +108,10 @@ class OpsService:
             return None
 
         # get tenant id
-        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+        tenant = db.session.query(App).filter(App.id == app_id).first()
+        if not tenant:
+            return None
+        tenant_id = tenant.tenant_id
         tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
         if project_url:
             tracing_config["project_url"] = project_url
@@ -139,7 +148,10 @@ class OpsService:
             return None
 
         # get tenant id
-        tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
+        tenant = db.session.query(App).filter(App.id == app_id).first()
+        if not tenant:
+            return None
+        tenant_id = tenant.tenant_id
         tracing_config = OpsTraceManager.encrypt_tracing_config(
             tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config
         )
diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py
index 4704d533a9..523aebeed5 100644
--- a/api/services/recommend_app/buildin/buildin_retrieval.py
+++ b/api/services/recommend_app/buildin/buildin_retrieval.py
@@ -41,7 +41,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
                 Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8")
             )
 
-        return cls.builtin_data
+        return cls.builtin_data or {}
 
     @classmethod
     def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
@@ -50,8 +50,8 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
         :param language: language
         :return:
         """
-        builtin_data = cls._get_builtin_data()
-        return builtin_data.get("recommended_apps", {}).get(language)
+        builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+        return builtin_data.get("recommended_apps", {}).get(language, {})
 
     @classmethod
     def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
@@ -60,5 +60,5 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
         :param app_id: App ID
         :return:
         """
-        builtin_data = cls._get_builtin_data()
+        builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
         return builtin_data.get("app_details", {}).get(app_id)
diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py
index 995d3755bb..3295516cce 100644
--- a/api/services/recommend_app/database/database_retrieval.py
+++ b/api/services/recommend_app/database/database_retrieval.py
@@ -57,13 +57,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
 
             recommended_app_result = {
                 "id": recommended_app.id,
-                "app": {
-                    "id": app.id,
-                    "name": app.name,
-                    "mode": app.mode,
-                    "icon": app.icon,
-                    "icon_background": app.icon_background,
-                },
+                "app": recommended_app.app,
                 "app_id": recommended_app.app_id,
                 "description": site.description,
                 "copyright": site.copyright,
diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py
index b0607a2132..80e1aefc01 100644
--- a/api/services/recommend_app/remote/remote_retrieval.py
+++ b/api/services/recommend_app/remote/remote_retrieval.py
@@ -47,8 +47,8 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         response = requests.get(url, timeout=(3, 10))
         if response.status_code != 200:
             return None
-
-        return response.json()
+        data: dict = response.json()
+        return data
 
     @classmethod
     def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
@@ -63,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         if response.status_code != 200:
             raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
 
-        result = response.json()
+        result: dict = response.json()
 
         if "categories" in result:
             result["categories"] = sorted(result["categories"])
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 4660316fcf..54c5845515 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -33,5 +33,5 @@ class RecommendedAppService:
         """
         mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
         retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
-        result = retrieval_instance.get_recommend_app_detail(app_id)
+        result: dict = retrieval_instance.get_recommend_app_detail(app_id)
         return result
diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py
index 9fe3cecce7..4cb8700117 100644
--- a/api/services/saved_message_service.py
+++ b/api/services/saved_message_service.py
@@ -13,6 +13,8 @@ class SavedMessageService:
     def pagination_by_last_id(
         cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
     ) -> InfiniteScrollPagination:
+        if not user:
+            raise ValueError("User is required")
         saved_messages = (
             db.session.query(SavedMessage)
             .filter(
@@ -31,6 +33,8 @@ class SavedMessageService:
 
     @classmethod
     def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+        if not user:
+            return
         saved_message = (
             db.session.query(SavedMessage)
             .filter(
@@ -59,6 +63,8 @@ class SavedMessageService:
 
     @classmethod
     def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
+        if not user:
+            return
         saved_message = (
             db.session.query(SavedMessage)
             .filter(
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index a374bdcf00..9600601633 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -1,7 +1,7 @@
 import uuid
 from typing import Optional
 
-from flask_login import current_user
+from flask_login import current_user  # type: ignore
 from sqlalchemy import func
 from werkzeug.exceptions import NotFound
 
@@ -21,7 +21,7 @@ class TagService:
         if keyword:
             query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
         query = query.group_by(Tag.id)
-        results = query.order_by(Tag.created_at.desc()).all()
+        results: list = query.order_by(Tag.created_at.desc()).all()
         return results
 
     @staticmethod
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 0e70b5e94d..6f848d49c4 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -1,5 +1,7 @@
 import json
 import logging
+from collections.abc import Mapping
+from typing import Any, cast
 
 from httpx import get
 
@@ -27,12 +29,12 @@ logger = logging.getLogger(__name__)
 
 class ApiToolManageService:
     @staticmethod
-    def parser_api_schema(schema: str) -> list[ApiToolBundle]:
+    def parser_api_schema(schema: str) -> Mapping[str, Any]:
         """
         parse api schema to tool bundle
         """
         try:
-            warnings = {}
+            warnings: dict[str, str] = {}
             try:
                 tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
             except Exception as e:
@@ -67,13 +69,16 @@ class ApiToolManageService:
                 ),
             ]
 
-            return jsonable_encoder(
-                {
-                    "schema_type": schema_type,
-                    "parameters_schema": tool_bundles,
-                    "credentials_schema": credentials_schema,
-                    "warning": warnings,
-                }
+            return cast(
+                Mapping,
+                jsonable_encoder(
+                    {
+                        "schema_type": schema_type,
+                        "parameters_schema": tool_bundles,
+                        "credentials_schema": credentials_schema,
+                        "warning": warnings,
+                    }
+                ),
             )
         except Exception as e:
             raise ValueError(f"invalid schema: {str(e)}")
@@ -125,7 +130,7 @@ class ApiToolManageService:
             raise ValueError(f"provider {provider_name} already exists")
 
         # parse openapi to tool bundle
-        extra_info = {}
+        extra_info: dict[str, str] = {}
         # extra info like description will be set here
         tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
 
@@ -196,7 +201,7 @@ class ApiToolManageService:
             # try to parse schema, avoid SSRF attack
             ApiToolManageService.parser_api_schema(schema)
-        except Exception as e:
+        except Exception:
             logger.exception("parse api schema error")
             raise ValueError("invalid schema, please check the url you provided")
 
@@ -265,9 +270,8 @@ class ApiToolManageService:
         if provider is None:
             raise ValueError(f"api provider {provider_name} does not exists")
 
-        # parse openapi to tool bundle
-        extra_info = {}
+        extra_info: dict[str, str] = {}
         # extra info like description will be set here
         tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
 
@@ -368,7 +372,7 @@ class ApiToolManageService:
 
         try:
             tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
-        except Exception as e:
+        except Exception:
             raise ValueError("invalid schema")
 
         # get tool bundle
@@ -467,7 +471,7 @@ class ApiToolManageService:
 
             tools = provider_controller.get_tools(tenant_id=tenant_id)
 
-            for tool in tools:
+            for tool in tools or []:
                 user_provider.tools.append(
                     ToolTransformService.convert_tool_entity_to_api_entity(
                         tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
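The cast(Mapping, jsonable_encoder(...)) in parser_api_schema above is worth a note: typing.cast changes only what the type checker sees and performs no runtime conversion or validation. A small sketch of that trade-off, with to_payload standing in for the service method:

from collections.abc import Mapping
from typing import Any, cast


def to_payload(obj: Any) -> Mapping[str, Any]:
    encoded = {"value": obj}  # stand-in for jsonable_encoder(...)
    # cast() is a no-op at runtime, so the caller relies on the encoder
    # actually producing a mapping here; nothing is checked.
    return cast(Mapping[str, Any], encoded)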
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index c78ef6b894..c40d05d2cc 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -2,6 +2,8 @@ import json
 import logging
 from pathlib import Path
 
+from sqlalchemy.orm import Session
+
 from configs import dify_config
 from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -22,7 +24,7 @@ logger = logging.getLogger(__name__)
 
 class BuiltinToolManageService:
     @staticmethod
-    def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
+    def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
         """
         list builtin tool provider tools
 
@@ -50,8 +52,8 @@ class BuiltinToolManageService:
             credentials = builtin_provider.credentials
             credentials = tool_provider_configurations.decrypt(credentials)
 
-        result = []
-        for tool in tools:
+        result: list[ToolApiEntity] = []
+        for tool in tools or []:
             result.append(
                 ToolTransformService.convert_tool_entity_to_api_entity(
                     tool=tool,
@@ -107,7 +109,9 @@ class BuiltinToolManageService:
         return jsonable_encoder(provider.get_credentials_schema())
 
     @staticmethod
-    def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
+    def update_builtin_tool_provider(
+        session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
+    ):
         """
         update builtin tool provider
         """
@@ -155,13 +159,10 @@ class BuiltinToolManageService:
                 encrypted_credentials=json.dumps(credentials),
             )
 
-            db.session.add(provider)
-            db.session.commit()
+            session.add(provider)
 
         else:
             provider.encrypted_credentials = json.dumps(credentials)
-            db.session.add(provider)
-            db.session.commit()
 
             # delete cache
             tool_configuration.delete_tool_credentials_cache()
@@ -169,11 +170,11 @@ class BuiltinToolManageService:
         return {"result": "success"}
 
     @staticmethod
-    def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
+    def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
         """
         get builtin tool provider credentials
         """
-        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
+        provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
 
         if provider_obj is None:
             return {}
@@ -245,9 +246,8 @@ class BuiltinToolManageService:
             db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
 
         # find provider
-        find_provider = lambda provider: next(
-            filter(lambda db_provider: db_provider.provider == provider, db_providers), None
-        )
+        def find_provider(provider):
+            return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
 
         result: list[ToolProviderApiEntity] = []
 
@@ -261,6 +261,8 @@ class BuiltinToolManageService:
                 name_func=lambda x: x.identity.name,
             ):
                 continue
+            if provider_controller.identity is None:
+                continue
 
             # convert provider controller to user provider
             user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
@@ -273,7 +275,7 @@ class BuiltinToolManageService:
             ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
 
             tools = provider_controller.get_tools()
-            for tool in tools:
+            for tool in tools or []:
                 user_builtin_provider.tools.append(
                     ToolTransformService.convert_tool_entity_to_api_entity(
                         tenant_id=tenant_id,
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index b532397813..b5565e986d 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -46,7 +46,7 @@ class ToolTransformService:
                 if isinstance(icon, str):
                     return json.loads(icon)
                 return icon
-            except:
+            except Exception:
                 return {"background": "#252525", "content": "\ud83d\ude01"}
         return ""
 
@@ -191,6 +191,8 @@ class ToolTransformService:
         convert provider controller to user provider
         """
         username = "Anonymous"
+        if db_provider.user is None:
+            raise ValueError(f"user is None for api provider {db_provider.id}")
         try:
             user = db_provider.user
             if not user:
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index f3e5ac0503..dc7d4a858c 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -1,11 +1,12 @@
 import json
 from collections.abc import Mapping, Sequence
 from datetime import datetime
-from typing import Any
+from typing import Any, Optional
 
 from sqlalchemy import or_
 
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
@@ -32,7 +33,7 @@ class WorkflowToolManageService:
         label: str,
         icon: dict,
         description: str,
-        parameters: Mapping[str, Any],
+        parameters: list[Mapping[str, Any]],
         privacy_policy: str = "",
         labels: list[str] | None = None,
     ) -> dict:
@@ -98,7 +99,7 @@ class WorkflowToolManageService:
         label: str,
         icon: dict,
         description: str,
-        parameters: list[dict],
+        parameters: list[Mapping[str, Any]],
         privacy_policy: str = "",
         labels: list[str] | None = None,
     ) -> dict:
@@ -190,11 +191,11 @@ class WorkflowToolManageService:
         for provider in db_tools:
             try:
                 tools.append(ToolTransformService.workflow_provider_to_controller(provider))
-            except:
+            except Exception:
                 # skip deleted tools
                 pass
 
-        labels = ToolLabelManager.get_tools_labels(tools)
+        labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)])
 
         result = []
 
@@ -284,6 +285,9 @@ class WorkflowToolManageService:
             raise ValueError("Workflow not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
+        if to_user_tool is None or len(to_user_tool) == 0:
+            raise ValueError(f"Tool {db_tool.id} not found")
 
         return {
             "name": db_tool.name,
@@ -321,6 +325,9 @@ class WorkflowToolManageService:
             raise ValueError(f"Tool {workflow_tool_id} not found")
 
         tool = ToolTransformService.workflow_provider_to_controller(db_tool)
+        to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
+        if to_user_tool is None or len(to_user_tool) == 0:
+            raise ValueError(f"Tool {workflow_tool_id} not found")
 
         return [
             ToolTransformService.convert_tool_entity_to_api_entity(
diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py
index d7ccc964cb..f698ed3084 100644
--- a/api/services/web_conversation_service.py
+++ b/api/services/web_conversation_service.py
@@ -1,5 +1,8 @@
 from typing import Optional, Union
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -13,6 +16,8 @@ class WebConversationService:
     @classmethod
     def pagination_by_last_id(
         cls,
+        *,
+        session: Session,
         app_model: App,
         user: Optional[Union[Account, EndUser]],
         last_id: Optional[str],
@@ -21,26 +26,29 @@ class WebConversationService:
         pinned: Optional[bool] = None,
         sort_by="-updated_at",
     ) -> InfiniteScrollPagination:
+        if not user:
+            raise ValueError("User is required")
         include_ids = None
         exclude_ids = None
-        if pinned is not None:
-            pinned_conversations = (
-                db.session.query(PinnedConversation)
-                .filter(
+        if pinned is not None and user:
+            stmt = (
+                select(PinnedConversation.conversation_id)
+                .where(
                     PinnedConversation.app_id == app_model.id,
                     PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                     PinnedConversation.created_by == user.id,
                 )
                 .order_by(PinnedConversation.created_at.desc())
-                .all()
             )
-            pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
+            pinned_conversation_ids = session.scalars(stmt).all()
+
             if pinned:
                 include_ids = pinned_conversation_ids
             else:
                 exclude_ids = pinned_conversation_ids
 
         return ConversationService.pagination_by_last_id(
+            session=session,
             app_model=app_model,
             user=user,
             last_id=last_id,
@@ -53,6 +61,8 @@ class WebConversationService:
 
     @classmethod
     def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+        if not user:
+            return
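The web_conversation_service.py hunk above moves from the legacy db.session.query(...).all() style to SQLAlchemy 2.0 select() plus Session.scalars(), with the session injected by the caller. A runnable sketch of the same move under SQLAlchemy 2.x, using a minimal stand-in for the real PinnedConversation model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class PinnedConversation(Base):  # minimal stand-in for the real model
    __tablename__ = "pinned_conversations"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column(String(36))
    conversation_id: Mapped[str] = mapped_column(String(36))


def pinned_ids(session: Session, app_id: str) -> list[str]:
    # 2.0 style: build a statement, let the caller-provided session execute it.
    stmt = select(PinnedConversation.conversation_id).where(PinnedConversation.app_id == app_id)
    return list(session.scalars(stmt))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    print(pinned_ids(session, "app-1"))  # -> []

Letting the caller own the session also allows the pinned-id lookup and the subsequent pagination query to share one transaction.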
         pinned_conversation = (
             db.session.query(PinnedConversation)
             .filter(
@@ -83,6 +93,8 @@ class WebConversationService:
 
     @classmethod
     def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
+        if not user:
+            return
         pinned_conversation = (
             db.session.query(PinnedConversation)
             .filter(
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 230f5d7815..1ad7d0399d 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,8 +1,9 @@
 import datetime
 import json
+from typing import Any
 
 import requests
-from flask_login import current_user
+from flask_login import current_user  # type: ignore
 
 from core.helper import encrypter
 from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
@@ -23,9 +24,9 @@ class WebsiteService:
     @classmethod
     def crawl_url(cls, args: dict) -> dict:
-        provider = args.get("provider")
+        provider = args.get("provider", "")
         url = args.get("url")
-        options = args.get("options")
+        options = args.get("options", "")
         credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
         if provider == "firecrawl":
             # decrypt api_key
@@ -164,16 +165,18 @@ class WebsiteService:
         return crawl_status_data
 
     @classmethod
-    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
+    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
         credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
         # decrypt api_key
         api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+        # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
+        data: Any
         if provider == "firecrawl":
             file_key = "website_files/" + job_id + ".txt"
             if storage.exists(file_key):
-                data = storage.load_once(file_key)
-                if data:
-                    data = json.loads(data.decode("utf-8"))
+                d = storage.load_once(file_key)
+                if d:
+                    data = json.loads(d.decode("utf-8"))
             else:
                 firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
                 result = firecrawl_app.check_crawl_status(job_id)
@@ -183,22 +186,17 @@ class WebsiteService:
             if data:
                 for item in data:
                     if item.get("source_url") == url:
-                        return item
+                        return dict(item)
             return None
         elif provider == "jinareader":
-            file_key = "website_files/" + job_id + ".txt"
-            if storage.exists(file_key):
-                data = storage.load_once(file_key)
-                if data:
-                    data = json.loads(data.decode("utf-8"))
-            elif not job_id:
+            if not job_id:
                 response = requests.get(
                     f"https://r.jina.ai/{url}",
                     headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
                 )
                 if response.json().get("code") != 200:
                     raise ValueError("Failed to crawl")
-                return response.json().get("data")
+                return dict(response.json().get("data", {}))
             else:
                 api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
                 response = requests.post(
@@ -218,12 +216,13 @@ class WebsiteService:
                 data = response.json().get("data", {})
                 for item in data.get("processed", {}).values():
                     if item.get("data", {}).get("url") == url:
-                        return item.get("data", {})
+                        return dict(item.get("data", {}))
+            return None
         else:
             raise ValueError("Invalid provider")
 
     @classmethod
-    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
+    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
         credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
         if provider == "firecrawl":
             # decrypt api_key
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 90b5cc4836..2b0d57bdfd 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -1,5 +1,5 @@
 import json
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.app_config.entities import (
     DatasetEntity,
@@ -101,7 +101,7 @@ class WorkflowConverter:
         app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config)
 
         # init workflow graph
-        graph = {"nodes": [], "edges": []}
+        graph: dict[str, Any] = {"nodes": [], "edges": []}
 
         # Convert list:
         # - variables -> start
@@ -118,7 +118,7 @@ class WorkflowConverter:
         graph["nodes"].append(start_node)
 
         # convert to http request node
-        external_data_variable_node_mapping = {}
+        external_data_variable_node_mapping: dict[str, str] = {}
         if app_config.external_data_variables:
             http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node(
                 app_model=app_model,
@@ -199,15 +199,16 @@ class WorkflowConverter:
         return workflow
 
     def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
-        app_mode = AppMode.value_of(app_model.mode)
-        if app_mode == AppMode.AGENT_CHAT or app_model.is_agent:
+        app_mode_enum = AppMode.value_of(app_model.mode)
+        app_config: EasyUIBasedAppConfig
+        if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
             app_model.mode = AppMode.AGENT_CHAT.value
             app_config = AgentChatAppConfigManager.get_app_config(
                 app_model=app_model, app_model_config=app_model_config
             )
-        elif app_mode == AppMode.CHAT:
+        elif app_mode_enum == AppMode.CHAT:
             app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
-        elif app_mode == AppMode.COMPLETION:
+        elif app_mode_enum == AppMode.COMPLETION:
             app_config = CompletionAppConfigManager.get_app_config(
                 app_model=app_model, app_model_config=app_model_config
             )
@@ -302,7 +303,7 @@ class WorkflowConverter:
             nodes.append(http_request_node)
 
             # append code node for response body parsing
-            code_node = {
+            code_node: dict[str, Any] = {
                 "id": f"code_{index}",
                 "position": None,
                 "data": {
@@ -401,6 +402,7 @@ class WorkflowConverter:
         )
 
         role_prefix = None
+        prompts: Any = None
 
         # Chat Model
         if model_config.mode == LLMMode.CHAT.value:
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 8cfa29ee4d..80b29814ac 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -1,6 +1,8 @@
 import threading
 
 import contexts
+from typing import Optional
+
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models.enums import WorkflowRunTriggeredFrom
@@ -95,7 +97,7 @@ class WorkflowRunService:
 
         return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
 
-    def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
+    def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
         """
         Get workflow run detail
 
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 502915f468..95649106e2 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -298,7 +298,7 @@ class WorkflowService:
             raise ValueError("Node run failed with no run result")
         # single step debug mode error handling return
         if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
-            node_error_args = {
+            node_error_args: dict[str, Any] = {
                 "status": WorkflowNodeExecutionStatus.EXCEPTION,
                 "error": node_run_result.error,
                 "inputs": node_run_result.inputs,
@@ -388,7 +388,7 @@ class WorkflowService:
             raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
 
         # convert to workflow
-        new_app = workflow_converter.convert_to_workflow(
+        new_app: App = workflow_converter.convert_to_workflow(
             app_model=app_model,
             account=account,
             name=args.get("name", "Default Name"),
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 8fcb12b1cb..7637b31454 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -1,4 +1,4 @@
-from flask_login import current_user
+from flask_login import current_user  # type: ignore
 
 from configs import dify_config
 from extensions.ext_database import db
@@ -29,6 +29,7 @@ class WorkspaceService:
             .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
             .first()
         )
+        assert tenant_account_join is not None, "TenantAccountJoin not found"
         tenant_info["role"] = tenant_account_join.role
 
         can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py b/api/tasks/__init__.py
similarity index 100%
rename from api/tests/integration_tests/model_runtime/vessl_ai/__init__.py
rename to api/tasks/__init__.py
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index 09be661216..50bb2b6e63 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py
index 25c55bcfaf..aab21a4410 100644
--- a/api/tasks/annotation/add_annotation_to_index_task.py
+++ b/api/tasks/annotation/add_annotation_to_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index fa7e5ac919..06162b02d6 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py
index f0f6b32b06..a6a598ce4b 100644
--- a/api/tasks/annotation/delete_annotation_index_task.py
+++ b/api/tasks/annotation/delete_annotation_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.datasource.vdb.vector_factory import Vector
 from models.dataset import Dataset
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index a2f4913513..26bf1c7c9f 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index 0bdcd0eccd..b42af0c7fa 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.datasource.vdb.vector_factory import Vector
diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py
index b685d84d07..8c675feaa6 100644
--- a/api/tasks/annotation/update_annotation_to_index_task.py
+++ b/api/tasks/annotation/update_annotation_to_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 41e1419d25..ce3d65526c 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -4,7 +4,7 @@ import time
 import uuid
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from sqlalchemy import func
 
 from core.indexing_runner import IndexingRunner
@@ -106,6 +106,6 @@ def batch_create_segment_to_index_task(
         logging.info(
             click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green")
         )
-    except Exception as e:
+    except Exception:
         logging.exception("Segments batch created index failed")
         redis_client.setex(indexing_cache_key, 600, "error")
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 7548efad6a..c48eb2e320 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.rag_web_reader import get_image_upload_file_ids
@@ -71,6 +71,8 @@ def clean_dataset_task(
                     image_upload_file_ids = get_image_upload_file_ids(segment.content)
                     for upload_file_id in image_upload_file_ids:
                         image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                        if image_file is None:
+                            continue
                         try:
                             storage.delete(image_file.key)
                         except Exception:
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 4a3dd75889..05eb9fd625 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -3,7 +3,7 @@ import time
 from typing import Optional
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.rag_web_reader import get_image_upload_file_ids
@@ -44,6 +44,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                 image_upload_file_ids = get_image_upload_file_ids(segment.content)
                 for upload_file_id in image_upload_file_ids:
                     image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                    if image_file is None:
+                        continue
                     try:
                         storage.delete(image_file.key)
                     except Exception:
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 75d9e03130..f5d6406d9c 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 315b01f157..dfa053a43c 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -4,7 +4,7 @@ import time
 from typing import Optional
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index cfc54920e2..b025509aeb 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import Document
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index c3e0ea5d9f..45a612c745 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 15e1e50076..f30a1cc7ac 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 1831691393..ac4e81f95d 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 734dd2478a..21b571b6cb 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 1a52a6636b..5f1e9a892f 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index f4c3dbd2e2..6db2620eb6 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
     start_at = time.perf_counter()
 
     dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    if dataset is None:
+        raise ValueError("Dataset not found")
 
     # check document limit
     features = FeatureService.get_features(dataset.tenant_id)
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 12639db939..2f6eb7b82a 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py
index d78fc2b891..5dc935548f 100644
--- a/api/tasks/mail_email_code_login.py
+++ b/api/tasks/mail_email_code_login.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from flask import render_template
 
 from extensions.ext_mail import mail
diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py
index c7dfb9bf60..3094527fd4 100644
--- a/api/tasks/mail_invite_member_task.py
+++ b/api/tasks/mail_invite_member_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from flask import render_template
 
 from configs import dify_config
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py
index 8596ca07cf..d5be94431b 100644
--- a/api/tasks/mail_reset_password_task.py
+++ b/api/tasks/mail_reset_password_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from flask import render_template
 
 from extensions.ext_mail import mail
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
index 34c62dc923..bb3b9e17ea 100644
--- a/api/tasks/ops_trace_task.py
+++ b/api/tasks/ops_trace_task.py
@@ -1,7 +1,7 @@
 import json
 import logging
 
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from flask import current_app
 
 from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index 934eb7430c..b603d689ba 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 66f78636ec..c3910e2be3 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -3,7 +3,7 @@ import time
 from collections.abc import Callable
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from sqlalchemy import delete
 from sqlalchemy.exc import SQLAlchemyError
 
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 1909eaf341..4ba6d1a83e 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -2,7 +2,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 from werkzeug.exceptions import NotFound
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 73471fd6e7..485caa5152 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.indexing_runner import IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
     Usage: retry_document_indexing_task.delay(dataset_id, document_id)
     """
-    documents = []
+    documents: list[Document] = []
     start_at = time.perf_counter()
 
     dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    if not dataset:
+        raise ValueError("Dataset not found")
+
     for document_id in document_ids:
         retry_indexing_cache_key = "document_{}_is_retried".format(document_id)
         # check document limit
@@ -55,29 +58,31 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
         document = (
             db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
         )
+        if not document:
+            logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
+            return
         try:
-            if document:
-                # clean old data
-                index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+            # clean old data
+            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
 
-                segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
-                if segments:
-                    index_node_ids = [segment.index_node_id for segment in segments]
-                    # delete from vector index
-                    index_processor.clean(dataset, index_node_ids)
+            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+                # delete from vector index
+                index_processor.clean(dataset, index_node_ids)
 
-                    for segment in segments:
-                        db.session.delete(segment)
-                    db.session.commit()
-
-                document.indexing_status = "parsing"
-                document.processing_started_at = datetime.datetime.utcnow()
-                db.session.add(document)
+                for segment in segments:
+                    db.session.delete(segment)
                 db.session.commit()
 
-                indexing_runner = IndexingRunner()
-                indexing_runner.run([document])
-                redis_client.delete(retry_indexing_cache_key)
+            document.indexing_status = "parsing"
+            document.processing_started_at = datetime.datetime.utcnow()
+            db.session.add(document)
+            db.session.commit()
+
+            indexing_runner = IndexingRunner()
+            indexing_runner.run([document])
+            redis_client.delete(retry_indexing_cache_key)
         except Exception as ex:
             document.indexing_status = "error"
             document.error = str(ex)
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index 1d2a338c83..5d6b069cf4 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -3,7 +3,7 @@ import logging
 import time
 
 import click
-from celery import shared_task
+from celery import shared_task  # type: ignore
 
 from core.indexing_runner import IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     start_at = time.perf_counter()
 
     dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+    if dataset is None:
+        raise ValueError("Dataset not found")
 
     sync_indexing_cache_key = "document_{}_is_sync".format(document_id)
     # check document limit
@@ -52,29 +54,31 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
     logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
 
     document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+    if not document:
+        logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
+        return
     try:
-        if document:
-            # clean old data
-            index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
+        # clean old data
+        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
 
-            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-                # delete from vector index
-                index_processor.clean(dataset, index_node_ids)
+        segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+        if segments:
+            index_node_ids = [segment.index_node_id for segment in segments]
+            # delete from vector index
+            index_processor.clean(dataset, index_node_ids)
 
-                for segment in segments:
-                    db.session.delete(segment)
-                db.session.commit()
-
-            document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.utcnow()
-            db.session.add(document)
+            for segment in segments:
+                db.session.delete(segment)
             db.session.commit()
 
-            indexing_runner = IndexingRunner()
-            indexing_runner.run([document])
-            redis_client.delete(sync_indexing_cache_key)
+        document.indexing_status = "parsing"
+        document.processing_started_at = datetime.datetime.utcnow()
+        db.session.add(document)
+        db.session.commit()
+
+        indexing_runner = IndexingRunner()
+        indexing_runner.run([document])
+        redis_client.delete(sync_indexing_cache_key)
     except Exception as ex:
         document.indexing_status = "error"
         document.error = str(ex)
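Both retry_document_indexing_task and sync_website_document_indexing_task above swap an `if document:` wrapper for an early return, which un-indents the whole happy path and lets mypy narrow the variable to non-None. The shape, reduced to a sketch with illustrative names:

import logging
from typing import Callable, Optional


def process(document_id: str, lookup: Callable[[str], Optional[str]]) -> None:
    document = lookup(document_id)  # stand-in for the db query returning Optional[Document]
    if not document:
        logging.info("Document not found: %s", document_id)
        return
    # The happy path continues unindented; document is non-None from here on.
    logging.info("Indexing %s", document)


process("doc-1", lambda _id: None)  # logs "Document not found: doc-1"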
indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 64f2884c4b..57fba31763 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -1,6 +1,6 @@ from typing import Any -import toml +import toml # type: ignore def load_api_poetry_configs() -> dict[str, Any]: @@ -38,7 +38,7 @@ def test_group_dependencies_version_operator(): ) -def test_duplicated_dependency_crossing_groups(): +def test_duplicated_dependency_crossing_groups() -> None: all_dependency_names: list[str] = [] for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py index 6371694694..5e3ee6bedc 100644 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from app_fixture import app, mock_user +from app_fixture import mock_user # type: ignore def test_post_requires_login(app): diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py deleted file mode 100644 index 753c52ce31..0000000000 --- a/api/tests/integration_tests/model_runtime/gitee_ai/test_llm.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -from collections.abc import Generator - -import pytest - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageTool, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gitee_ai.llm.llm import GiteeAILargeLanguageModel - - -def test_predefined_models(): - model = GiteeAILargeLanguageModel() - model_schemas = model.predefined_models() - - assert len(model_schemas) >= 1 - assert isinstance(model_schemas[0], AIModelEntity) - - -def test_validate_credentials_for_chat_model(): - model = GiteeAILargeLanguageModel() - - with pytest.raises(CredentialsValidateFailedError): - # model name to gpt-3.5-turbo because of mocking - model.validate_credentials(model="gpt-3.5-turbo", credentials={"api_key": "invalid_key"}) - - model.validate_credentials( - model="Qwen2-7B-Instruct", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - ) - - -def test_invoke_chat_model(): - model = GiteeAILargeLanguageModel() - - result = model.invoke( - model="Qwen2-7B-Instruct", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - model_parameters={ - "temperature": 0.0, - 
"top_p": 1.0, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "max_tokens": 10, - "stream": False, - }, - stop=["How"], - stream=False, - user="foo", - ) - - assert isinstance(result, LLMResult) - assert len(result.message.content) > 0 - - -def test_invoke_stream_chat_model(): - model = GiteeAILargeLanguageModel() - - result = model.invoke( - model="Qwen2-7B-Instruct", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - model_parameters={"temperature": 0.0, "max_tokens": 100, "stream": False}, - stream=True, - user="foo", - ) - - assert isinstance(result, Generator) - - for chunk in result: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True - if chunk.delta.finish_reason is not None: - assert chunk.delta.usage is not None - - -def test_get_num_tokens(): - model = GiteeAILargeLanguageModel() - - num_tokens = model.get_num_tokens( - model="Qwen2-7B-Instruct", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - prompt_messages=[UserPromptMessage(content="Hello World!")], - ) - - assert num_tokens == 10 - - num_tokens = model.get_num_tokens( - model="Qwen2-7B-Instruct", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - tools=[ - PromptMessageTool( - name="get_weather", - description="Determine weather in my location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, - "unit": {"type": "string", "enum": ["c", "f"]}, - }, - "required": ["location"], - }, - ), - ], - ) - - assert num_tokens == 77 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py deleted file mode 100644 index f12ed54a45..0000000000 --- a/api/tests/integration_tests/model_runtime/gitee_ai/test_provider.py +++ /dev/null @@ -1,15 +0,0 @@ -import os - -import pytest - -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gitee_ai.gitee_ai import GiteeAIProvider - - -def test_validate_provider_credentials(): - provider = GiteeAIProvider() - - with pytest.raises(CredentialsValidateFailedError): - provider.validate_provider_credentials(credentials={"api_key": "invalid_key"}) - - provider.validate_provider_credentials(credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py deleted file mode 100644 index 0e5914a61f..0000000000 --- a/api/tests/integration_tests/model_runtime/gitee_ai/test_rerank.py +++ /dev/null @@ -1,47 +0,0 @@ -import os - -import pytest - -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gitee_ai.rerank.rerank import GiteeAIRerankModel - - -def test_validate_credentials(): - model = GiteeAIRerankModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model="bge-reranker-v2-m3", - credentials={"api_key": "invalid_key"}, - ) - - model.validate_credentials( - model="bge-reranker-v2-m3", - credentials={ - "api_key": os.environ.get("GITEE_AI_API_KEY"), - }, - ) - - -def test_invoke_model(): - model = GiteeAIRerankModel() - result = model.invoke( - model="bge-reranker-v2-m3", - credentials={ - "api_key": os.environ.get("GITEE_AI_API_KEY"), - }, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - top_n=1, - score_threshold=0.01, - ) - - assert isinstance(result, RerankResult) - assert len(result.docs) == 1 - assert result.docs[0].score >= 0.01 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py deleted file mode 100644 index 4a01453fdd..0000000000 --- a/api/tests/integration_tests/model_runtime/gitee_ai/test_speech2text.py +++ /dev/null @@ -1,45 +0,0 @@ -import os - -import pytest - -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gitee_ai.speech2text.speech2text import GiteeAISpeech2TextModel - - -def test_validate_credentials(): - model = GiteeAISpeech2TextModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model="whisper-base", - credentials={"api_key": "invalid_key"}, - ) - - model.validate_credentials( - model="whisper-base", - credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, - ) - - -def test_invoke_model(): - model = GiteeAISpeech2TextModel() - - # Get the directory of the current file - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # Get assets directory - assets_dir = os.path.join(os.path.dirname(current_dir), "assets") - - # Construct the path to the audio file - audio_file_path = os.path.join(assets_dir, "audio.mp3") - - # Open the file and get the file object - with open(audio_file_path, "rb") as audio_file: - file = audio_file - - result = model.invoke( - model="whisper-base", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}, file=file - ) - - assert isinstance(result, str) - assert result == "1 2 3 4 5 6 7 8 9 10" diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py deleted file mode 100644 index 34648f0bc8..0000000000 --- a/api/tests/integration_tests/model_runtime/gitee_ai/test_text_embedding.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -import pytest - -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gitee_ai.text_embedding.text_embedding import GiteeAIEmbeddingModel - - -def test_validate_credentials(): - model = GiteeAIEmbeddingModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": "invalid_key"}) - - model.validate_credentials(model="bge-large-zh-v1.5", credentials={"api_key": os.environ.get("GITEE_AI_API_KEY")}) - - -def test_invoke_model(): - model = GiteeAIEmbeddingModel() - - result = model.invoke( - model="bge-large-zh-v1.5", - credentials={ - "api_key": os.environ.get("GITEE_AI_API_KEY"), - }, - texts=["hello", "world"], - user="user", - ) - - assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 2 - - -def test_get_num_tokens(): - model = GiteeAIEmbeddingModel() - - num_tokens = model.get_num_tokens( - model="bge-large-zh-v1.5", - credentials={ - "api_key": os.environ.get("GITEE_AI_API_KEY"), - }, - texts=["hello", "world"], - ) - - assert num_tokens == 2 diff --git a/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py b/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py deleted file mode 100644 index 9f18161a7b..0000000000 --- 
a/api/tests/integration_tests/model_runtime/gitee_ai/test_tts.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from core.model_runtime.model_providers.gitee_ai.tts.tts import GiteeAIText2SpeechModel - - -def test_invoke_model(): - model = GiteeAIText2SpeechModel() - - result = model.invoke( - model="speecht5_tts", - tenant_id="test", - credentials={ - "api_key": os.environ.get("GITEE_AI_API_KEY"), - }, - content_text="Hello, world!", - voice="", - ) - - content = b"" - for chunk in result: - content += chunk - - assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py deleted file mode 100644 index f56ad0dadc..0000000000 --- a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py +++ /dev/null @@ -1,49 +0,0 @@ -import os - -import pytest - -from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import ( - GPUStackTextEmbeddingModel, -) - - -def test_validate_credentials(): - model = GPUStackTextEmbeddingModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model="bge-m3", - credentials={ - "endpoint_url": "invalid_url", - "api_key": "invalid_api_key", - }, - ) - - model.validate_credentials( - model="bge-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - }, - ) - - -def test_invoke_model(): - model = GPUStackTextEmbeddingModel() - - result = model.invoke( - model="bge-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "context_size": 8192, - }, - texts=["hello", "world"], - user="abc-123", - ) - - assert isinstance(result, TextEmbeddingResult) - assert len(result.embeddings) == 2 - assert result.usage.total_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py deleted file mode 100644 index 326b7b16f0..0000000000 --- a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from collections.abc import Generator - -import pytest - -from core.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, -) -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageTool, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel - - -def test_validate_credentials_for_chat_model(): - model = GPUStackLanguageModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model="llama-3.2-1b-instruct", - credentials={ - "endpoint_url": "invalid_url", - "api_key": "invalid_api_key", - "mode": "chat", - }, - ) - - model.validate_credentials( - model="llama-3.2-1b-instruct", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "chat", - }, - ) - - -def test_invoke_completion_model(): - model = GPUStackLanguageModel() - - response = model.invoke( - model="llama-3.2-1b-instruct", - credentials={ - 
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "completion", - }, - prompt_messages=[UserPromptMessage(content="ping")], - model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, - stop=[], - user="abc-123", - stream=False, - ) - - assert isinstance(response, LLMResult) - assert len(response.message.content) > 0 - assert response.usage.total_tokens > 0 - - -def test_invoke_chat_model(): - model = GPUStackLanguageModel() - - response = model.invoke( - model="llama-3.2-1b-instruct", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "chat", - }, - prompt_messages=[UserPromptMessage(content="ping")], - model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, - stop=[], - user="abc-123", - stream=False, - ) - - assert isinstance(response, LLMResult) - assert len(response.message.content) > 0 - assert response.usage.total_tokens > 0 - - -def test_invoke_stream_chat_model(): - model = GPUStackLanguageModel() - - response = model.invoke( - model="llama-3.2-1b-instruct", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "chat", - }, - prompt_messages=[UserPromptMessage(content="Hello World!")], - model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, - stop=["you"], - stream=True, - user="abc-123", - ) - - assert isinstance(response, Generator) - for chunk in response: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True - - -def test_get_num_tokens(): - model = GPUStackLanguageModel() - - num_tokens = model.get_num_tokens( - model="????", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - tools=[ - PromptMessageTool( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. 
San Francisco, CA", - }, - "unit": {"type": "string", "enum": ["c", "f"]}, - }, - "required": ["location"], - }, - ) - ], - ) - - assert isinstance(num_tokens, int) - assert num_tokens == 80 - - num_tokens = model.get_num_tokens( - model="????", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - "mode": "chat", - }, - prompt_messages=[UserPromptMessage(content="Hello World!")], - ) - - assert isinstance(num_tokens, int) - assert num_tokens == 10 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py deleted file mode 100644 index f5c2d2d21c..0000000000 --- a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py +++ /dev/null @@ -1,107 +0,0 @@ -import os - -import pytest - -from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.gpustack.rerank.rerank import ( - GPUStackRerankModel, -) - - -def test_validate_credentials_for_rerank_model(): - model = GPUStackRerankModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model="bge-reranker-v2-m3", - credentials={ - "endpoint_url": "invalid_url", - "api_key": "invalid_api_key", - }, - ) - - model.validate_credentials( - model="bge-reranker-v2-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - }, - ) - - -def test_invoke_rerank_model(): - model = GPUStackRerankModel() - - response = model.invoke( - model="bge-reranker-v2-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - }, - query="Organic skincare products for sensitive skin", - docs=[ - "Eco-friendly kitchenware for modern homes", - "Biodegradable cleaning supplies for eco-conscious consumers", - "Organic cotton baby clothes for sensitive skin", - "Natural organic skincare range for sensitive skin", - "Tech gadgets for smart homes: 2024 edition", - "Sustainable gardening tools and compost solutions", - "Sensitive skin-friendly facial cleansers and toners", - "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials", - ], - top_n=3, - score_threshold=-0.75, - user="abc-123", - ) - - assert isinstance(response, RerankResult) - assert len(response.docs) == 3 - - -def test__invoke(): - model = GPUStackRerankModel() - - # Test case 1: Empty docs - result = model._invoke( - model="bge-reranker-v2-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - }, - query="Organic skincare products for sensitive skin", - docs=[], - top_n=3, - score_threshold=0.75, - user="abc-123", - ) - assert isinstance(result, RerankResult) - assert len(result.docs) == 0 - - # Test case 2: Expected docs - result = model._invoke( - model="bge-reranker-v2-m3", - credentials={ - "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), - "api_key": os.environ.get("GPUSTACK_API_KEY"), - }, - query="Organic skincare products for sensitive skin", - docs=[ - "Eco-friendly kitchenware for modern homes", - "Biodegradable cleaning supplies for eco-conscious consumers", - "Organic cotton baby clothes for sensitive skin", - "Natural organic skincare range for sensitive skin", - "Tech gadgets for smart homes: 2024 edition", 
- "Sustainable gardening tools and compost solutions", - "Sensitive skin-friendly facial cleansers and toners", - "Organic food wraps and storage solutions", - "Yoga mats made from recycled materials", - ], - top_n=3, - score_threshold=-0.75, - user="abc-123", - ) - assert isinstance(result, RerankResult) - assert len(result.docs) == 3 - assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py deleted file mode 100644 index 7797d0f8e4..0000000000 --- a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -from collections.abc import Generator - -import pytest - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel - - -def test_validate_credentials(): - model = VesslAILargeLanguageModel() - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": "invalid_key", - "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), - "mode": "chat", - }, - ) - - with pytest.raises(CredentialsValidateFailedError): - model.validate_credentials( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": os.environ.get("VESSL_AI_API_KEY"), - "endpoint_url": "http://invalid_url", - "mode": "chat", - }, - ) - - model.validate_credentials( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": os.environ.get("VESSL_AI_API_KEY"), - "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), - "mode": "chat", - }, - ) - - -def test_invoke_model(): - model = VesslAILargeLanguageModel() - - response = model.invoke( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": os.environ.get("VESSL_AI_API_KEY"), - "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Who are you?"), - ], - model_parameters={ - "temperature": 1.0, - "top_k": 2, - "top_p": 0.5, - }, - stop=["How"], - stream=False, - user="abc-123", - ) - - assert isinstance(response, LLMResult) - assert len(response.message.content) > 0 - - -def test_invoke_stream_model(): - model = VesslAILargeLanguageModel() - - response = model.invoke( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": os.environ.get("VESSL_AI_API_KEY"), - "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Who are you?"), - ], - model_parameters={ - "temperature": 1.0, - "top_k": 2, - "top_p": 0.5, - }, - stop=["How"], - stream=True, - user="abc-123", - ) - - assert isinstance(response, Generator) - - for chunk in response: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - - -def test_get_num_tokens(): - model = VesslAILargeLanguageModel() - - num_tokens = 
model.get_num_tokens( - model=os.environ.get("VESSL_AI_MODEL_NAME"), - credentials={ - "api_key": os.environ.get("VESSL_AI_API_KEY"), - "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - ) - - assert isinstance(num_tokens, int) - assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py deleted file mode 100644 index 33c803e8e1..0000000000 --- a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py +++ /dev/null @@ -1,21 +0,0 @@ -import os -from time import sleep - -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel - - -def test_invoke_bce_reranker_base_v1(): - sleep(3) - model = WenxinRerankModel() - - response = model.invoke( - model="bce-reranker-base_v1", - credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, - query="What is Deep Learning?", - docs=["Deep Learning is ...", "My Book is ..."], - user="abc-123", - ) - - assert isinstance(response, RerankResult) - assert len(response.docs) == 2 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py deleted file mode 100644 index 647a2f6480..0000000000 --- a/api/tests/integration_tests/model_runtime/x/test_llm.py +++ /dev/null @@ -1,204 +0,0 @@ -import os -from collections.abc import Generator - -import pytest - -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageTool, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel - -"""FOR MOCK FIXTURES, DO NOT REMOVE""" -from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock - - -def test_predefined_models(): - model = XAILargeLanguageModel() - model_schemas = model.predefined_models() - - assert len(model_schemas) >= 1 - assert isinstance(model_schemas[0], AIModelEntity) - - -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_validate_credentials_for_chat_model(setup_openai_mock): - model = XAILargeLanguageModel() - - with pytest.raises(CredentialsValidateFailedError): - # model name to gpt-3.5-turbo because of mocking - model.validate_credentials( - model="gpt-3.5-turbo", - credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, - ) - - model.validate_credentials( - model="grok-beta", - credentials={ - "api_key": os.environ.get("XAI_API_KEY"), - "endpoint_url": os.environ.get("XAI_API_BASE"), - "mode": "chat", - }, - ) - - -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_invoke_chat_model(setup_openai_mock): - model = XAILargeLanguageModel() - - result = model.invoke( - model="grok-beta", - credentials={ - "api_key": os.environ.get("XAI_API_KEY"), - "endpoint_url": os.environ.get("XAI_API_BASE"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful 
AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - model_parameters={ - "temperature": 0.0, - "top_p": 1.0, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "max_tokens": 10, - }, - stop=["How"], - stream=False, - user="foo", - ) - - assert isinstance(result, LLMResult) - assert len(result.message.content) > 0 - - -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_invoke_chat_model_with_tools(setup_openai_mock): - model = XAILargeLanguageModel() - - result = model.invoke( - model="grok-beta", - credentials={ - "api_key": os.environ.get("XAI_API_KEY"), - "endpoint_url": os.environ.get("XAI_API_BASE"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage( - content="what's the weather today in London?", - ), - ], - model_parameters={"temperature": 0.0, "max_tokens": 100}, - tools=[ - PromptMessageTool( - name="get_weather", - description="Determine weather in my location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, - "unit": {"type": "string", "enum": ["c", "f"]}, - }, - "required": ["location"], - }, - ), - PromptMessageTool( - name="get_stock_price", - description="Get the current stock price", - parameters={ - "type": "object", - "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, - "required": ["symbol"], - }, - ), - ], - stream=False, - user="foo", - ) - - assert isinstance(result, LLMResult) - assert isinstance(result.message, AssistantPromptMessage) - - -@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) -def test_invoke_stream_chat_model(setup_openai_mock): - model = XAILargeLanguageModel() - - result = model.invoke( - model="grok-beta", - credentials={ - "api_key": os.environ.get("XAI_API_KEY"), - "endpoint_url": os.environ.get("XAI_API_BASE"), - "mode": "chat", - }, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - model_parameters={"temperature": 0.0, "max_tokens": 100}, - stream=True, - user="foo", - ) - - assert isinstance(result, Generator) - - for chunk in result: - assert isinstance(chunk, LLMResultChunk) - assert isinstance(chunk.delta, LLMResultChunkDelta) - assert isinstance(chunk.delta.message, AssistantPromptMessage) - assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True - if chunk.delta.finish_reason is not None: - assert chunk.delta.usage is not None - assert chunk.delta.usage.completion_tokens > 0 - - -def test_get_num_tokens(): - model = XAILargeLanguageModel() - - num_tokens = model.get_num_tokens( - model="grok-beta", - credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, - prompt_messages=[UserPromptMessage(content="Hello World!")], - ) - - assert num_tokens == 10 - - num_tokens = model.get_num_tokens( - model="grok-beta", - credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, - prompt_messages=[ - SystemPromptMessage( - content="You are a helpful AI assistant.", - ), - UserPromptMessage(content="Hello World!"), - ], - tools=[ - PromptMessageTool( - name="get_weather", - description="Determine weather in my location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state 
e.g. San Francisco, CA"}, - "unit": {"type": "string", "enum": ["c", "f"]}, - }, - "required": ["location"], - }, - ), - ], - ) - - assert num_tokens == 77 diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9..2860739f0e 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restful import Api, Resource # type: ignore app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py index 64b748d6e9..9acc94e110 100644 --- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py +++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py @@ -43,9 +43,9 @@ def test_api_tool(setup_http_mock): response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters) assert response.status_code == 200 - assert "/p_param" == response.request.url.path - assert b"query_param=q_param" == response.request.url.query - assert "h_param" == response.request.headers.get("header_param") - assert "application/json" == response.request.headers.get("content-type") - assert "cookie_param=c_param" == response.request.headers.get("cookie") + assert response.request.url.path == "/p_param" + assert response.request.url.query == b"query_param=q_param" + assert response.request.headers.get("header_param") == "h_param" + assert response.request.headers.get("content-type") == "application/json" + assert response.request.headers.get("cookie") == "cookie_param=c_param" assert "b_param" in response.content.decode() diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 0ea61369c0..4af35a8bef 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient -from pymochow.model.database import Database -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState -from pymochow.model.schema import HNSWParams, VectorIndex -from pymochow.model.table import Table +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore from requests.adapters import HTTPAdapter diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 61d6ed1656..68a1e290ad 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -4,12 +4,12 @@ from typing import Optional import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient -from tcvectordb.model.database import Collection, Database -from tcvectordb.model.document import Document, Filter -from tcvectordb.model.enum import ReadConsistency 
-from tcvectordb.model.index import Index -from xinference_client.types import Embedding +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 0f40337feb..3ad72e5550 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Collection, Data, DistanceType, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 9eea63f722..0507fc7075 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -384,7 +384,7 @@ def test_mock_404(setup_http_mock): assert result.outputs is not None resp = result.outputs - assert 404 == resp.get("status_code") + assert resp.get("status_code") == 404 assert "Not Found" in resp.get("body", "") diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 385eb08c36..efa9ea8979 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -59,6 +59,8 @@ def test_dify_config(example_env_file): # annotated field with configured value assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 + assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3 + # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. 
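# For illustration alongside the WORKFLOW_PARALLEL_DEPTH_LIMIT assertion added
# in the hunk above: a minimal, self-contained sketch of asserting a
# dotenv-backed setting. ExampleConfig, the env-file path, and the test name
# are hypothetical stand-ins built on pydantic-settings (which dify_config
# itself uses), not Dify's actual config classes.
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleConfig(BaseSettings):
    # Load values from an example env file when present; unknown keys in the
    # file are ignored, and the default of 3 applies when the key is absent.
    model_config = SettingsConfigDict(env_file=".env.example", extra="ignore")

    WORKFLOW_PARALLEL_DEPTH_LIMIT: int = 3


def test_workflow_parallel_depth_limit_default() -> None:
    config = ExampleConfig()
    assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3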
diff --git a/api/tests/unit_tests/configs/test_opendal_config_parse.py b/api/tests/unit_tests/configs/test_opendal_config_parse.py deleted file mode 100644 index 94de40450b..0000000000 --- a/api/tests/unit_tests/configs/test_opendal_config_parse.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from extensions.storage.opendal_storage import is_r2_endpoint - - -@pytest.mark.parametrize( - ("endpoint", "expected"), - [ - ("https://bucket.r2.cloudflarestorage.com", True), - ("https://custom-domain.r2.cloudflarestorage.com/", True), - ("https://bucket.r2.cloudflarestorage.com/path", True), - ("https://s3.amazonaws.com", False), - ("https://storage.googleapis.com", False), - ("http://localhost:9000", False), - ("invalid-url", False), - ("", False), - ], -) -def test_is_r2_endpoint(endpoint: str, expected: bool): - assert is_r2_endpoint(endpoint) == expected diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 0c264c15a0..426557c716 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -2,6 +2,8 @@ import pytest from pydantic import ValidationError from core.variables import ( + ArrayFileVariable, + ArrayVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -81,3 +83,8 @@ def test_variable_to_object(): assert var.to_object() == 3.14 var = SecretVariable(name="secret", value="secret_value") assert var.to_object() == "secret_value" + + +def test_array_file_variable_is_array_variable(): + var = ArrayFileVariable(name="files", value=[]) + assert isinstance(var, ArrayVariable) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 7d19cff3e8..ee0f7672f8 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest +from configs import dify_config from core.app.app_config.entities import ModelConfigEntity from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory @@ -126,6 +127,7 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args): def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args): model_config_mock, _, messages, inputs, context = get_chat_model_args + dify_config.MULTIMODAL_SEND_FORMAT = "url" files = [ File( @@ -134,13 +136,16 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="", ) ] prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: - mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url)) + mock_get_encoded_string.return_value = ImagePromptMessageContent( + url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" + ) prompt_messages = prompt_transform._get_chat_model_prompt_messages( prompt_template=messages, inputs=inputs, diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 4edbc01cc7..e02d882780 100644 --- 
a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,34 +1,9 @@ import json -from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow -def test_file_loads_and_dumps(): - file = File( - id="file1", - tenant_id="tenant1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image1.jpg", - ) - - file_dict = file.model_dump() - assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY - assert file_dict["type"] == file.type.value - assert isinstance(file_dict["type"], str) - assert file_dict["transfer_method"] == file.transfer_method.value - assert isinstance(file_dict["transfer_method"], str) - assert "_extra_config" not in file_dict - - file_obj = File.model_validate(file_dict) - assert file_obj.id == file.id - assert file_obj.tenant_id == file.tenant_id - assert file_obj.type == file.type - assert file_obj.transfer_method == file.transfer_method - assert file_obj.remote_url == file.remote_url - - def test_file_to_dict(): file = File( id="file1", @@ -36,10 +11,11 @@ def test_file_to_dict(): type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", + storage_key="storage_key", ) file_dict = file.to_dict() - assert "_extra_config" not in file_dict + assert "_storage_key" not in file_dict assert "url" in file_dict diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 9f1ba7b6af..b7d8f69e8c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -488,14 +488,12 @@ def test_run_branch(mock_close, mock_remove): items = [] generator = graph_engine.run() for item in generator: - # print(type(item), item) items.append(item) assert len(items) == 10 assert items[3].route_node_state.node_id == "if-else-1" assert items[4].route_node_state.node_id == "if-else-1" assert isinstance(items[5], NodeRunStreamChunkEvent) - assert items[5].chunk_content == "1 " assert isinstance(items[6], NodeRunStreamChunkEvent) assert items[6].chunk_content == "takato" assert items[7].route_node_state.node_id == "answer-1" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 70ec023140..97bacada74 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) @@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch): type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1111", + storage_key="", ), ), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 9a24d35a1f..76db42ef10 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -18,11 +18,11 @@ from core.model_runtime.entities.message_entities 
import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute @@ -158,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node): filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ) llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) @@ -174,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + storage_key="", ), File( id="2", @@ -182,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node): filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", + storage_key="", ), ] llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) @@ -225,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + storage_key="", ) ] fake_query = faker.sentence() prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=files, + sys_query=fake_query, + sys_files=files, context=None, memory=None, model_config=model_config, @@ -249,8 +253,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Setup dify config - dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" - dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" + dify_config.MULTIMODAL_SEND_FORMAT = "url" # Generate fake values for prompt template fake_assistant_prompt = faker.sentence() @@ -285,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): test_scenarios = [ LLMNodeTestScenario( description="No files", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], features=[], vision_enabled=False, vision_detail=None, @@ -320,14 +323,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), LLMNodeTestScenario( description="User files", - user_query=fake_query, - user_files=[ + sys_query=fake_query, + sys_files=[ File( tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", ) ], vision_enabled=True, @@ -361,15 +367,17 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): UserPromptMessage( content=[ TextPromptMessageContent(data=fake_query), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + 
ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), ] ), ], ), LLMNodeTestScenario( description="Prompt template with variable selector of File", - user_query=fake_query, - user_files=[], + sys_query=fake_query, + sys_files=[], vision_enabled=False, vision_detail=fake_vision_detail, features=[ModelFeature.VISION], @@ -384,7 +392,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): expected_messages=[ UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ImagePromptMessageContent( + url=fake_remote_url, mime_type="image/jpg", format="jpg", detail=fake_vision_detail + ), ] ), ] @@ -397,6 +407,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension=".jpg", + mime_type="image/jpg", + storage_key="", ) }, ), @@ -411,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=scenario.user_query, - user_files=scenario.user_files, + sys_query=scenario.sys_query, + sys_files=scenario.sys_files, context=fake_context, memory=memory, model_config=model_config, @@ -429,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): assert ( prompt_messages == scenario.expected_messages ), f"Message content mismatch in scenario: {scenario.description}" + + +def test_handle_list_messages_basic(llm_node): + messages = [ + LLMNodeChatModelMessage( + text="Hello, {#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ] + context = "world" + jinja2_variables = [] + variable_pool = llm_node.graph_runtime_state.variable_pool + vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH + + result = llm_node._handle_list_messages( + messages=messages, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail_config, + ) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == [TextPromptMessageContent(data="Hello, world")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index 8e39445baf..21bb857353 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel): """Test scenario for LLM node testing.""" description: str = Field(..., description="Description of the test scenario") - user_query: str = Field(..., description="User query input") - user_files: Sequence[File] = Field(default_factory=list, description="List of user files") + sys_query: str = Field(..., description="User query input") + sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") vision_enabled: bool = Field(default=False, description="Whether vision is enabled") vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ba209e4020..2d74be9da9 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, NodeRunExceptionEvent, NodeRunStreamChunkEvent, ) @@ -14,7 +13,9 @@ from models.workflow import WorkflowType class ContinueOnErrorTestHelper: @staticmethod - def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): + def get_code_node( + code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} + ): """Helper method to create a code node configuration""" node = { "id": "node", @@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper: "code_language": "python3", "code": "\n".join([line[4:] for line in code.split("\n")]), "type": "code", + **retry_config, }, } if default_value: @@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper: @staticmethod def get_http_node( - error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False + error_strategy: str = "fail-branch", + default_value: dict | None = None, + authorization_success: bool = False, + retry_config: dict = {}, ): """Helper method to create a http node configuration""" authorization = ( @@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper: "body": None, "type": "http-request", "error_strategy": error_strategy, + **retry_config, }, } if default_value: diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d964d0e352..41e2c5d484 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -248,6 +248,7 @@ def test_array_file_contains_file_name(): transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", + storage_key="", ), ], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d20dfc5b31..36116d3540 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", + storage_key="", ), File( filename="document1.pdf", @@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", + storage_key="", ), File( filename="image2.png", @@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", + storage_key="", ), File( filename="audio1.mp3", @@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node): tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", + storage_key="", ), ] variable = ArrayFileSegment(value=files) @@ -130,6 +134,7 @@ def test_get_file_extract_string_func(): mime_type="text/plain", remote_url="https://example.com/test_file.txt", related_id="test_related_id", + storage_key="", ) # Test each case @@ -150,6 +155,7 @@ def test_get_file_extract_string_func(): mime_type=None, remote_url=None, 
related_id="test_related_id", + storage_key="", ) assert _get_file_extract_string_func(key="name")(empty_file) == "" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py new file mode 100644 index 0000000000..c232875ce5 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py @@ -0,0 +1,73 @@ +from core.workflow.graph_engine.entities.event import ( + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunSucceededEvent, + NodeRunRetryEvent, +) +from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + + +def test_retry_default_value_partial_success(): + """retry default value node with partial success status""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", + [{"key": "result", "type": "string", "value": "http node got error response"}], + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert events[-1].outputs == {"answer": "http node got error response"} + assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) + assert len(events) == 11 + + +def test_retry_failed(): + """retry failed with success status""" + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + None, + None, + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert any(isinstance(e, GraphRunFailedEvent) for e in events) + assert len(events) == 8 diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9ea6acac17..efbcdc760c 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -19,6 +19,7 @@ def file(): related_id="test_related_id", remote_url="test_url", filename="test_file.txt", + storage_key="", ) diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py index 27e1c0ad85..4f6d8a2f54 100644 --- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py +++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch -from oss2 import Bucket -from 
-from oss2.models import GetObjectResult, PutObjectResult
+from oss2 import Bucket  # type: ignore
+from oss2.models import GetObjectResult, PutObjectResult  # type: ignore
from tests.unit_tests.oss.__mock.base import (
    get_example_bucket,
diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py
index 5189b68e87..c77c5b08f3 100644
--- a/api/tests/unit_tests/oss/__mock/tencent_cos.py
+++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py
@@ -3,8 +3,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client
-from qcloud_cos.streambody import StreamBody
+from qcloud_cos import CosS3Client  # type: ignore
+from qcloud_cos.streambody import StreamBody  # type: ignore
from tests.unit_tests.oss.__mock.base import (
    get_example_bucket,
diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py
index 649d93a202..88df59f91c 100644
--- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py
+++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
import pytest
from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
+from tos import TosClientV2  # type: ignore
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput  # type: ignore
from tests.unit_tests.oss.__mock.base import (
    get_example_bucket,
diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
index 65d31352bd..380134bc46 100644
--- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
+++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py
@@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch
import pytest
-from oss2 import Auth
+from oss2 import Auth  # type: ignore
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock
diff --git a/api/tests/unit_tests/oss/opendal/test_opendal.py b/api/tests/unit_tests/oss/opendal/test_opendal.py
index 1caee55677..6acec6e579 100644
--- a/api/tests/unit_tests/oss/opendal/test_opendal.py
+++ b/api/tests/unit_tests/oss/opendal/test_opendal.py
@@ -1,15 +1,12 @@
-import os
from collections.abc import Generator
from pathlib import Path
import pytest
-from configs.middleware.storage.opendal_storage_config import OpenDALScheme
from extensions.storage.opendal_storage import OpenDALStorage
from tests.unit_tests.oss.__mock.base import (
    get_example_data,
    get_example_filename,
-    get_example_filepath,
    get_opendal_bucket,
)
@@ -19,7 +16,7 @@ class TestOpenDAL:
    def setup_method(self, *args, **kwargs):
        """Executed before each test method."""
        self.storage = OpenDALStorage(
-            scheme=OpenDALScheme.FS,
+            scheme="fs",
            root=get_opendal_bucket(),
        )
diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
index 303f0493bd..d289751800 100644
--- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
+++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py
@@ -1,7 +1,7 @@ from unittest.mock import patch
import pytest
-from qcloud_cos import CosConfig
+from qcloud_cos import CosConfig  # type: ignore
from extensions.storage.tencent_cos_storage import TencentCosStorage
from tests.unit_tests.oss.__mock.base import (
diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
index 5afbc9e8b4..04988e85d8 100644
--- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
+++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py
@@ -1,5 +1,5 @@
import pytest
-from tos import TosClientV2
+from tos import TosClientV2  # type: ignore
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
from tests.unit_tests.oss.__mock.base import (
diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
index 95b93651d5..8d64548727 100644
--- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
+++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py
@@ -1,7 +1,7 @@ from textwrap import dedent
import pytest
-from yaml import YAMLError
+from yaml import YAMLError  # type: ignore
from core.tools.utils.yaml_utils import load_yaml_file
diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py
new file mode 100644
index 0000000000..08adc9ebe9
--- /dev/null
+++ b/dev/pytest/pytest_config_tests.py
@@ -0,0 +1,111 @@
+import yaml  # type: ignore
+from dotenv import dotenv_values
+from pathlib import Path
+
+BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
+    "APP_MAX_EXECUTION_TIME",
+    "BATCH_UPLOAD_LIMIT",
+    "CELERY_BEAT_SCHEDULER_TIME",
+    "CODE_EXECUTION_API_KEY",
+    "HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
+    "HTTP_REQUEST_MAX_READ_TIMEOUT",
+    "HTTP_REQUEST_MAX_WRITE_TIMEOUT",
+    "KEYWORD_DATA_SOURCE_TYPE",
+    "LOGIN_LOCKOUT_DURATION",
+    "LOG_FORMAT",
+    "OCI_ACCESS_KEY",
+    "OCI_BUCKET_NAME",
+    "OCI_ENDPOINT",
+    "OCI_REGION",
+    "OCI_SECRET_KEY",
+    "REDIS_DB",
+    "RESEND_API_URL",
+    "RESPECT_XFORWARD_HEADERS_ENABLED",
+    "SENTRY_DSN",
+    "SSRF_DEFAULT_CONNECT_TIME_OUT",
+    "SSRF_DEFAULT_MAX_RETRIES",
+    "SSRF_DEFAULT_READ_TIME_OUT",
+    "SSRF_DEFAULT_TIME_OUT",
+    "SSRF_DEFAULT_WRITE_TIME_OUT",
+    "UPSTASH_VECTOR_TOKEN",
+    "UPSTASH_VECTOR_URL",
+    "USING_UGC_INDEX",
+    "WEAVIATE_BATCH_SIZE",
+    "WEAVIATE_GRPC_ENABLED",
+}
+
+BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = {
+    "BATCH_UPLOAD_LIMIT",
+    "CELERY_BEAT_SCHEDULER_TIME",
+    "HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
+    "HTTP_REQUEST_MAX_READ_TIMEOUT",
+    "HTTP_REQUEST_MAX_WRITE_TIMEOUT",
+    "KEYWORD_DATA_SOURCE_TYPE",
+    "LOGIN_LOCKOUT_DURATION",
+    "LOG_FORMAT",
+    "OPENDAL_FS_ROOT",
+    "OPENDAL_S3_ACCESS_KEY_ID",
+    "OPENDAL_S3_BUCKET",
+    "OPENDAL_S3_ENDPOINT",
+    "OPENDAL_S3_REGION",
+    "OPENDAL_S3_ROOT",
+    "OPENDAL_S3_SECRET_ACCESS_KEY",
+    "OPENDAL_S3_SERVER_SIDE_ENCRYPTION",
+    "PGVECTOR_MAX_CONNECTION",
+    "PGVECTOR_MIN_CONNECTION",
+    "PGVECTO_RS_DATABASE",
+    "PGVECTO_RS_HOST",
+    "PGVECTO_RS_PASSWORD",
+    "PGVECTO_RS_PORT",
+    "PGVECTO_RS_USER",
+    "RESPECT_XFORWARD_HEADERS_ENABLED",
+    "SCARF_NO_ANALYTICS",
+    "SSRF_DEFAULT_CONNECT_TIME_OUT",
+    "SSRF_DEFAULT_MAX_RETRIES",
+    "SSRF_DEFAULT_READ_TIME_OUT",
+    "SSRF_DEFAULT_TIME_OUT",
+    "SSRF_DEFAULT_WRITE_TIME_OUT",
+    "STORAGE_OPENDAL_SCHEME",
+    "SUPABASE_API_KEY",
+    "SUPABASE_BUCKET_NAME",
+    "SUPABASE_URL",
+    "USING_UGC_INDEX",
+    "VIKINGDB_CONNECTION_TIMEOUT",
+    "VIKINGDB_SOCKET_TIMEOUT",
+    "WEAVIATE_BATCH_SIZE",
+    "WEAVIATE_GRPC_ENABLED",
+}
+
+API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys())
+DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys())
+DOCKER_COMPOSE_CONFIG_SET = set()
+
+with open(Path("docker") / Path("docker-compose.yaml")) as f:
+    DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys())
+
+
+def test_yaml_config():
+    # set difference: any key left over exists in one config source but not the other
+    DIFF_API_WITH_DOCKER = (
+        API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
+    )
+    if DIFF_API_WITH_DOCKER:
+        print(
+            f"API and Docker config sets differ on these keys: {DIFF_API_WITH_DOCKER}"
+        )
+        raise Exception("API and Docker config sets are different")
+    DIFF_API_WITH_DOCKER_COMPOSE = (
+        API_CONFIG_SET
+        - DOCKER_COMPOSE_CONFIG_SET
+        - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
+    )
+    if DIFF_API_WITH_DOCKER_COMPOSE:
+        print(
+            f"API and Docker Compose config sets differ on these keys: {DIFF_API_WITH_DOCKER_COMPOSE}"
+        )
+        raise Exception("API and Docker Compose config sets are different")
+    print("All tests passed!")
+
+
+if __name__ == "__main__":
+    test_yaml_config()
diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml
index 4392407a73..1cff58be7f 100644
--- a/docker-legacy/docker-compose.yaml
+++ b/docker-legacy/docker-compose.yaml
@@ -2,7 +2,7 @@ version: '3'
services:
  # API service
  api:
-    image: langgenius/dify-api:0.13.2
+    image: langgenius/dify-api:0.14.2
    restart: always
    environment:
      # Startup mode, 'api' starts the API server.
@@ -227,7 +227,7 @@ services:
  # worker service
  # The Celery worker for processing the queue.
  worker:
-    image: langgenius/dify-api:0.13.2
+    image: langgenius/dify-api:0.14.2
    restart: always
    environment:
      CONSOLE_WEB_URL: ''
@@ -397,7 +397,7 @@ services:
  # Frontend web application.
  web:
-    image: langgenius/dify-web:0.13.2
+    image: langgenius/dify-web:0.14.2
    restart: always
    environment:
      # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
diff --git a/docker/.env.example b/docker/.env.example
index db85e5d511..43e67a8db4 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -107,6 +107,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
+APP_MAX_EXECUTION_TIME=1200
# ------------------------------
# Container Startup Related Configuration
# ------------------------------
@@ -119,15 +120,15 @@ DIFY_BIND_ADDRESS=0.0.0.0
# API service binding port number, default 5001.
DIFY_PORT=5001
-# The number of API server workers, i.e., the number of gevent workers.
-# Formula: number of cpu cores x 2 + 1
+# The number of API server workers, i.e., the number of Gunicorn workers.
+# Formula: number of cpu cores x 2 + 1 for the sync worker class; 1 for the gevent worker class.
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
-# Similar to SERVER_WORKER_CLASS. Default is gevent.
+# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
CELERY_WORKER_CLASS=
@@ -227,6 +228,7 @@ REDIS_PORT=6379
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
+REDIS_DB=0
# Whether to use Redis Sentinel mode.
# If set to true, the application will automatically discover and connect to the master node through Sentinel.
@@ -281,57 +283,42 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
# ------------------------------
# The type of storage to use for storing user files.
-# Supported values are `opendal` , `s3` , `azure-blob` , `google-storage`, `tencent-cos`, `huawei-obs`, `volcengine-tos`, `baidu-obs`, `supabase`
-# Default: `opendal`
STORAGE_TYPE=opendal
-# Apache OpenDAL Configuration, refer to https://github.com/apache/opendal
-# The scheme for the OpenDAL storage.
-STORAGE_OPENDAL_SCHEME=fs
-# OpenDAL FS
+# Apache OpenDAL Configuration
+# The configuration for OpenDAL consists of the following format: OPENDAL_<SCHEME_NAME>_<CONFIG_NAME>.
+# You can find all the service configurations (CONFIG_NAME) in the repository at: https://github.com/apache/opendal/tree/main/core/src/services.
+# Dify will scan configurations starting with OPENDAL_<SCHEME_NAME> and automatically apply them.
+# The scheme name for the OpenDAL storage.
+OPENDAL_SCHEME=fs
+# Configurations for OpenDAL Local File System.
OPENDAL_FS_ROOT=storage
-# OpenDAL S3
-OPENDAL_S3_ROOT=/
-OPENDAL_S3_BUCKET=your-bucket-name
-OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
-OPENDAL_S3_ACCESS_KEY_ID=your-access-key
-OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
-OPENDAL_S3_REGION=your-region
-OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
# S3 Configuration
+#
+S3_ENDPOINT=
+S3_REGION=us-east-1
+S3_BUCKET_NAME=difyai
+S3_ACCESS_KEY=
+S3_SECRET_KEY=
# Whether to use AWS managed IAM roles for authenticating with the S3 service.
# If set to false, the access key and secret key must be provided.
S3_USE_AWS_MANAGED_IAM=false
-# The endpoint of the S3 service.
-S3_ENDPOINT=
-# The region of the S3 service.
-S3_REGION=us-east-1
-# The name of the S3 bucket to use for storing files.
-S3_BUCKET_NAME=difyai
-# The access key to use for authenticating with the S3 service.
-S3_ACCESS_KEY=
-# The secret key to use for authenticating with the S3 service.
-S3_SECRET_KEY=
# Azure Blob Configuration
-# The name of the Azure Blob Storage account to use for storing files.
+#
AZURE_BLOB_ACCOUNT_NAME=difyai
-# The access key to use for authenticating with the Azure Blob Storage account.
AZURE_BLOB_ACCOUNT_KEY=difyai
-# The name of the Azure Blob Storage container to use for storing files.
AZURE_BLOB_CONTAINER_NAME=difyai-container
-# The URL of the Azure Blob Storage account.
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Google Storage Configuration
-# The name of the Google Storage bucket to use for storing files.
+#
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
-# The service account JSON key to use for authenticating with the Google Storage service.
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# The Alibaba Cloud OSS configurations,
-# only available when STORAGE_TYPE is `aliyun-oss`
+#
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
ALIYUN_OSS_SECRET_KEY=your-secret-key
@@ -342,55 +329,47 @@ ALIYUN_OSS_AUTH_VERSION=v4
ALIYUN_OSS_PATH=your-path
# Tencent COS Configuration
-# The name of the Tencent COS bucket to use for storing files.
+#
TENCENT_COS_BUCKET_NAME=your-bucket-name
-# The secret key to use for authenticating with the Tencent COS service.
TENCENT_COS_SECRET_KEY=your-secret-key
-# The secret id to use for authenticating with the Tencent COS service.
TENCENT_COS_SECRET_ID=your-secret-id
-# The region of the Tencent COS service.
TENCENT_COS_REGION=your-region
-# The scheme of the Tencent COS service.
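The OPENDAL_S3_* block removed above is no longer needed precisely because of the scanning rule the new comments describe: any variable named OPENDAL_<SCHEME_NAME>_<CONFIG_NAME> is picked up and forwarded to the corresponding OpenDAL service. A minimal sketch of that rule, assuming the opendal Python binding and a hypothetical collect_opendal_kwargs helper (the real loader in the api codebase may differ in detail):

import os

import opendal  # the Apache OpenDAL Python binding

def collect_opendal_kwargs(scheme: str) -> dict[str, str]:
    # Gather OPENDAL_<SCHEME_NAME>_* variables into operator kwargs,
    # e.g. OPENDAL_FS_ROOT=storage -> {"root": "storage"}.
    prefix = f"OPENDAL_{scheme.upper()}_"
    return {
        key.removeprefix(prefix).lower(): value
        for key, value in os.environ.items()
        if key.startswith(prefix)
    }

scheme = os.environ.get("OPENDAL_SCHEME", "fs")
op = opendal.Operator(scheme, **collect_opendal_kwargs(scheme))

The remaining storage provider settings continue below.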
TENCENT_COS_SCHEME=your-scheme
+# Oracle Storage Configuration
+#
+OCI_ENDPOINT=https://objectstorage.us-ashburn-1.oraclecloud.com
+OCI_BUCKET_NAME=your-bucket-name
+OCI_ACCESS_KEY=your-access-key
+OCI_SECRET_KEY=your-secret-key
+OCI_REGION=us-ashburn-1
+
# Huawei OBS Configuration
-# The name of the Huawei OBS bucket to use for storing files.
+#
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
-# The secret key to use for authenticating with the Huawei OBS service.
HUAWEI_OBS_SECRET_KEY=your-secret-key
-# The access key to use for authenticating with the Huawei OBS service.
HUAWEI_OBS_ACCESS_KEY=your-access-key
-# The server url of the HUAWEI OBS service.
HUAWEI_OBS_SERVER=your-server-url
# Volcengine TOS Configuration
-# The name of the Volcengine TOS bucket to use for storing files.
+#
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
-# The secret key to use for authenticating with the Volcengine TOS service.
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
-# The access key to use for authenticating with the Volcengine TOS service.
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
-# The endpoint of the Volcengine TOS service.
VOLCENGINE_TOS_ENDPOINT=your-server-url
-# The region of the Volcengine TOS service.
VOLCENGINE_TOS_REGION=your-region
# Baidu OBS Storage Configuration
-# The name of the Baidu OBS bucket to use for storing files.
+#
BAIDU_OBS_BUCKET_NAME=your-bucket-name
-# The secret key to use for authenticating with the Baidu OBS service.
BAIDU_OBS_SECRET_KEY=your-secret-key
-# The access key to use for authenticating with the Baidu OBS service.
BAIDU_OBS_ACCESS_KEY=your-access-key
-# The endpoint of the Baidu OBS service.
BAIDU_OBS_ENDPOINT=your-server-url
# Supabase Storage Configuration
-# The name of the Supabase bucket to use for storing files.
+#
SUPABASE_BUCKET_NAME=your-bucket-name
-# The api key to use for authenticating with the Supabase service.
SUPABASE_API_KEY=your-access-key
-# The project endpoint url of the Supabase service.
SUPABASE_URL=your-server-url
# ------------------------------
@@ -403,28 +382,20 @@ VECTOR_STORE=weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT=http://weaviate:8080
-# The Weaviate API key.
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
QDRANT_URL=http://qdrant:6333
-# The Qdrant API key.
QDRANT_API_KEY=difyai123456
-# The Qdrant client timeout setting.
QDRANT_CLIENT_TIMEOUT=20
-# The Qdrant client enable gRPC mode.
QDRANT_GRPC_ENABLED=false
-# The Qdrant server gRPC mode PORT.
QDRANT_GRPC_PORT=6334
# Milvus configuration. Only available when VECTOR_STORE is `milvus`.
# The milvus uri.
MILVUS_URI=http://127.0.0.1:19530
-# The milvus token.
MILVUS_TOKEN=
-# The milvus username.
MILVUS_USER=root
-# The milvus password.
MILVUS_PASSWORD=Milvus
# MyScale configuration, only available when VECTOR_STORE is `myscale`
@@ -478,8 +449,8 @@ ANALYTICDB_MAX_CONNECTION=5
# TiDB vector configurations, only available when VECTOR_STORE is `tidb`
TIDB_VECTOR_HOST=tidb
TIDB_VECTOR_PORT=4000
-TIDB_VECTOR_USER=xxx.root
-TIDB_VECTOR_PASSWORD=xxxxxx
+TIDB_VECTOR_USER=
+TIDB_VECTOR_PASSWORD=
TIDB_VECTOR_DATABASE=dify
# Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant`
@@ -502,7 +473,7 @@ CHROMA_PORT=8000
CHROMA_TENANT=default_tenant
CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthClientProvider
-CHROMA_AUTH_CREDENTIALS=xxxxxx
+CHROMA_AUTH_CREDENTIALS=
# Oracle configuration, only available when VECTOR_STORE is `oracle`
ORACLE_HOST=oracle
@@ -539,6 +510,7 @@ ELASTICSEARCH_HOST=0.0.0.0
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
+KIBANA_PORT=5601
# baidu vector configurations, only available when VECTOR_STORE is `baidu`
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
@@ -558,11 +530,10 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
-
# Lindorm configuration, only available when VECTOR_STORE is `lindorm`
-LINDORM_URL=http://ld-***************-proxy-search-pub.lindorm.aliyuncs.com:30070
-LINDORM_USERNAME=username
-LINDORM_PASSWORD=password
+LINDORM_URL=http://lindorm:30070
+LINDORM_USERNAME=lindorm
+LINDORM_PASSWORD=lindorm
# OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase`
OCEANBASE_VECTOR_HOST=oceanbase
@@ -570,8 +541,13 @@ OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
+OCEANBASE_CLUSTER_NAME=difyai
OCEANBASE_MEMORY_LIMIT=6G
+
+# Upstash Vector configuration, only available when VECTOR_STORE is `upstash`
+UPSTASH_VECTOR_URL=https://xxx-vector.upstash.io
+UPSTASH_VECTOR_TOKEN=dify
+
# ------------------------------
# Knowledge Configuration
# ------------------------------
@@ -614,20 +590,16 @@ CODE_GENERATION_MAX_TOKENS=1024
# Multi-modal Configuration
# ------------------------------
-# The format of the image/video sent when the multi-modal model is input,
+# The format of the image/video/audio/document sent as input to the multi-modal model;
# the default is base64, optional url.
# The delay of the call in url mode will be lower than that in base64 mode.
# It is generally recommended to use the more compatible base64 mode.
-# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video.
-MULTIMODAL_SEND_IMAGE_FORMAT=base64
-MULTIMODAL_SEND_VIDEO_FORMAT=base64
-
+# If configured as url, you need to configure FILES_URL as an externally accessible address so that the multi-modal model can access the image/video/audio/document.
+MULTIMODAL_SEND_FORMAT=base64
# Upload image file size limit, default 10M.
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
-
# Upload video file size limit, default 100M.
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
-
# Upload audio file size limit, default 50M.
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
@@ -635,15 +607,14 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# ------------------------------
# Sentry Configuration
# Used for application monitoring and error log tracking.
# ------------------------------
+SENTRY_DSN=
# API Service Sentry DSN address, default is empty, when empty,
# all monitoring information is not reported to Sentry.
# If not set, Sentry error reporting will be disabled.
API_SENTRY_DSN= - # API Service The reporting ratio of Sentry events, if it is 0.01, it is 1%. API_SENTRY_TRACES_SAMPLE_RATE=1.0 - # API Service The reporting ratio of Sentry profiles, if it is 0.01, it is 1%. API_SENTRY_PROFILES_SAMPLE_RATE=1.0 @@ -681,8 +652,10 @@ MAIL_TYPE=resend MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. +RESEND_API_URL=https://api.resend.com RESEND_API_KEY=your-resend-api-key + # SMTP server configuration, used when MAIL_TYPE is `smtp` SMTP_SERVER= SMTP_PORT=465 @@ -707,24 +680,26 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 # The sandbox service endpoint. CODE_EXECUTION_ENDPOINT=http://sandbox:8194 +CODE_EXECUTION_API_KEY=dify-sandbox CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 CODE_MAX_DEPTH=5 CODE_MAX_PRECISION=20 CODE_MAX_STRING_LENGTH=80000 -TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 CODE_MAX_OBJECT_ARRAY_LENGTH=30 CODE_MAX_NUMBER_ARRAY_LENGTH=1000 CODE_EXECUTION_CONNECT_TIMEOUT=10 CODE_EXECUTION_READ_TIMEOUT=60 CODE_EXECUTION_WRITE_TIMEOUT=10 +TEMPLATE_TRANSFORM_MAX_LENGTH=80000 # Workflow runtime configuration WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 +WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration @@ -944,3 +919,7 @@ CSP_WHITELIST= # Enable or disable create tidb service job CREATE_TIDB_SERVICE_JOB_ENABLED=false + +# Maximum number of submitted thread count in a ThreadPool for parallel node execution +MAX_SUBMIT_COUNT=100 + diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml new file mode 100644 index 0000000000..d4e0ba49d0 --- /dev/null +++ b/docker/docker-compose-template.yaml @@ -0,0 +1,576 @@ +x-shared-env: &shared-api-worker-env +services: + # API service + api: + image: langgenius/dify-api:0.14.2 + restart: always + environment: + # Use the shared environment variables. + <<: *shared-api-worker-env + # Startup mode, 'api' starts the API server. + MODE: api + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + + # worker service + # The Celery worker for processing the queue. + worker: + image: langgenius/dify-api:0.14.2 + restart: always + environment: + # Use the shared environment variables. + <<: *shared-api-worker-env + # Startup mode, 'worker' starts the Celery worker for processing the queue. + MODE: worker + SENTRY_DSN: ${API_SENTRY_DSN:-} + SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} + SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} + depends_on: + - db + - redis + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/api/storage + networks: + - ssrf_proxy_network + - default + + # Frontend web application. 
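A note on the `<<: *shared-api-worker-env` lines in the api and worker services above: `x-shared-env` declares a YAML anchor, and the `<<` merge key copies that mapping's keys into each service's environment, so the shared settings are written once. A minimal sketch of the semantics using PyYAML (illustrative values; note that in the template the anchor's value is left empty and must be filled in by the generator, since merging a null anchor would fail):

import yaml

doc = """
x-shared-env: &shared-api-worker-env
  LOG_LEVEL: INFO
services:
  api:
    environment:
      <<: *shared-api-worker-env
      MODE: api
"""

env = yaml.safe_load(doc)["services"]["api"]["environment"]
assert env == {"LOG_LEVEL": "INFO", "MODE": "api"}  # shared key merged in, local key kept

The frontend web service definition follows.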
+ web: + image: langgenius/dify-web:0.14.2 + restart: always + environment: + CONSOLE_API_URL: ${CONSOLE_API_URL:-} + APP_API_URL: ${APP_API_URL:-} + SENTRY_DSN: ${WEB_SENTRY_DSN:-} + NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} + TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + CSP_WHITELIST: ${CSP_WHITELIST:-} + + # The postgres database. + db: + image: postgres:15-alpine + restart: always + environment: + PGUSER: ${PGUSER:-postgres} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} + POSTGRES_DB: ${POSTGRES_DB:-dify} + PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} + command: > + postgres -c 'max_connections=${POSTGRES_MAX_CONNECTIONS:-100}' + -c 'shared_buffers=${POSTGRES_SHARED_BUFFERS:-128MB}' + -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' + -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' + -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' + volumes: + - ./volumes/db/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # The redis cache. + redis: + image: redis:6-alpine + restart: always + environment: + REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456} + volumes: + # Mount the redis data directory to the container. + - ./volumes/redis/data:/data + # Set the redis password when startup redis server. + command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} + healthcheck: + test: ['CMD', 'redis-cli', 'ping'] + + # The DifySandbox + sandbox: + image: langgenius/dify-sandbox:0.2.10 + restart: always + environment: + # The DifySandbox configurations + # Make sure you are changing this key for your deployment with a strong key. + # You can generate a strong key using `openssl rand -base64 42`. + API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} + GIN_MODE: ${SANDBOX_GIN_MODE:-release} + WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} + ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} + HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} + SANDBOX_PORT: ${SANDBOX_PORT:-8194} + volumes: + - ./volumes/sandbox/dependencies:/dependencies + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] + networks: + - ssrf_proxy_network + + # ssrf_proxy server + # for more information, please refer to + # https://docs.dify.ai/learn-more/faq/install-faq#id-18.-why-is-ssrf_proxy-needed + ssrf_proxy: + image: ubuntu/squid:latest + restart: always + volumes: + - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template + - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh + entrypoint: + [ + 'sh', + '-c', + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] + environment: + # pls clearly modify the squid env vars to fit your network environment. + HTTP_PORT: ${SSRF_HTTP_PORT:-3128} + COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} + REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} + SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} + SANDBOX_PORT: ${SANDBOX_PORT:-8194} + networks: + - ssrf_proxy_network + - default + + # Certbot service + # use `docker-compose --profile certbot up` to start the certbot service. 
+  certbot:
+    image: certbot/certbot
+    profiles:
+      - certbot
+    volumes:
+      - ./volumes/certbot/conf:/etc/letsencrypt
+      - ./volumes/certbot/www:/var/www/html
+      - ./volumes/certbot/logs:/var/log/letsencrypt
+      - ./volumes/certbot/conf/live:/etc/letsencrypt/live
+      - ./certbot/update-cert.template.txt:/update-cert.template.txt
+      - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh
+    environment:
+      - CERTBOT_EMAIL=${CERTBOT_EMAIL}
+      - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
+      - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
+    entrypoint: ['/docker-entrypoint.sh']
+    command: ['tail', '-f', '/dev/null']
+
+  # The nginx reverse proxy.
+  # used for reverse proxying the API service and Web service.
+  nginx:
+    image: nginx:latest
+    restart: always
+    volumes:
+      - ./nginx/nginx.conf.template:/etc/nginx/nginx.conf.template
+      - ./nginx/proxy.conf.template:/etc/nginx/proxy.conf.template
+      - ./nginx/https.conf.template:/etc/nginx/https.conf.template
+      - ./nginx/conf.d:/etc/nginx/conf.d
+      - ./nginx/docker-entrypoint.sh:/docker-entrypoint-mount.sh
+      - ./nginx/ssl:/etc/ssl # cert dir (legacy)
+      - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
+      - ./volumes/certbot/conf:/etc/letsencrypt
+      - ./volumes/certbot/www:/var/www/html
+    entrypoint:
+      [
+        'sh',
+        '-c',
+        "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
+      ]
+    environment:
+      NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
+      NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
+      NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443}
+      NGINX_PORT: ${NGINX_PORT:-80}
+      # You're required to add your own SSL certificates/keys to the `./nginx/ssl` directory
+      # and modify the env vars below in .env if HTTPS_ENABLED is true.
+      NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
+      NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
+      NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
+      NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
+      NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
+      NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
+      NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
+      NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
+      NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
+      CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-}
+    depends_on:
+      - api
+      - web
+    ports:
+      - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}'
+      - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}'
+
+  # The TiDB vector store.
+  # For production use, please refer to https://github.com/pingcap/tidb-docker-compose
+  tidb:
+    image: pingcap/tidb:v8.4.0
+    profiles:
+      - tidb
+    command:
+      - --store=unistore
+    restart: always
+
+  # The Weaviate vector store.
+  weaviate:
+    image: semitechnologies/weaviate:1.19.0
+    profiles:
+      - ''
+      - weaviate
+    restart: always
+    volumes:
+      # Mount the Weaviate data directory to the container.
+      - ./volumes/weaviate:/var/lib/weaviate
+    environment:
+      # The Weaviate configurations
+      # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information.
+ PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} + QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: ${WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED:-false} + DEFAULT_VECTORIZER_MODULE: ${WEAVIATE_DEFAULT_VECTORIZER_MODULE:-none} + CLUSTER_HOSTNAME: ${WEAVIATE_CLUSTER_HOSTNAME:-node1} + AUTHENTICATION_APIKEY_ENABLED: ${WEAVIATE_AUTHENTICATION_APIKEY_ENABLED:-true} + AUTHENTICATION_APIKEY_ALLOWED_KEYS: ${WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} + AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai} + AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true} + AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai} + + # Qdrant vector store. + # (if used, you need to set VECTOR_STORE to qdrant in the api & worker service.) + qdrant: + image: langgenius/qdrant:v1.7.3 + profiles: + - qdrant + restart: always + volumes: + - ./volumes/qdrant:/qdrant/storage + environment: + QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} + + # The Couchbase vector store. + couchbase-server: + build: ./couchbase-server + profiles: + - couchbase + restart: always + environment: + - CLUSTER_NAME=dify_search + - COUCHBASE_ADMINISTRATOR_USERNAME=${COUCHBASE_USER:-Administrator} + - COUCHBASE_ADMINISTRATOR_PASSWORD=${COUCHBASE_PASSWORD:-password} + - COUCHBASE_BUCKET=${COUCHBASE_BUCKET_NAME:-Embeddings} + - COUCHBASE_BUCKET_RAMSIZE=512 + - COUCHBASE_RAM_SIZE=2048 + - COUCHBASE_EVENTING_RAM_SIZE=512 + - COUCHBASE_INDEX_RAM_SIZE=512 + - COUCHBASE_FTS_RAM_SIZE=1024 + hostname: couchbase-server + container_name: couchbase-server + working_dir: /opt/couchbase + stdin_open: true + tty: true + entrypoint: [""] + command: sh -c "/opt/couchbase/init/init-cbserver.sh" + volumes: + - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data + healthcheck: + # ensure bucket was created before proceeding + test: [ "CMD-SHELL", "curl -s -f -u Administrator:password http://localhost:8091/pools/default/buckets | grep -q '\\[{' || exit 1" ] + interval: 10s + retries: 10 + start_period: 30s + timeout: 10s + + # The pgvector vector database. + pgvector: + image: pgvector/pgvector:pg16 + profiles: + - pgvector + restart: always + environment: + PGUSER: ${PGVECTOR_PGUSER:-postgres} + # The password for the default postgres user. + POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} + # The name of the default postgres database. + POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} + # postgres data directory + PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + volumes: + - ./volumes/pgvector/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # pgvecto-rs vector store + pgvecto-rs: + image: tensorchord/pgvecto-rs:pg16-v0.3.0 + profiles: + - pgvecto-rs + restart: always + environment: + PGUSER: ${PGVECTOR_PGUSER:-postgres} + # The password for the default postgres user. + POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} + # The name of the default postgres database. 
+ POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify} + # postgres data directory + PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata} + volumes: + - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data + healthcheck: + test: ['CMD', 'pg_isready'] + interval: 1s + timeout: 3s + retries: 30 + + # Chroma vector database + chroma: + image: ghcr.io/chroma-core/chroma:0.5.20 + profiles: + - chroma + restart: always + volumes: + - ./volumes/chroma:/chroma/chroma + environment: + CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456} + CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider} + IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} + + # OceanBase vector database + oceanbase: + image: quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215 + profiles: + - oceanbase + restart: always + volumes: + - ./volumes/oceanbase/data:/root/ob + - ./volumes/oceanbase/conf:/root/.obd/cluster + - ./volumes/oceanbase/init.d:/root/boot/init.d + environment: + OB_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} + OB_SYS_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_TENANT_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} + OB_SERVER_IP: '127.0.0.1' + + # Oracle vector database + oracle: + image: container-registry.oracle.com/database/free:latest + profiles: + - oracle + restart: always + volumes: + - source: oradata + type: volume + target: /opt/oracle/oradata + - ./startupscripts:/opt/oracle/scripts/startup + environment: + ORACLE_PWD: ${ORACLE_PWD:-Dify123456} + ORACLE_CHARACTERSET: ${ORACLE_CHARACTERSET:-AL32UTF8} + + # Milvus vector database services + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.5 + profiles: + - milvus + environment: + ETCD_AUTO_COMPACTION_MODE: ${ETCD_AUTO_COMPACTION_MODE:-revision} + ETCD_AUTO_COMPACTION_RETENTION: ${ETCD_AUTO_COMPACTION_RETENTION:-1000} + ETCD_QUOTA_BACKEND_BYTES: ${ETCD_QUOTA_BACKEND_BYTES:-4294967296} + ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} + volumes: + - ./volumes/milvus/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + healthcheck: + test: ['CMD', 'etcdctl', 'endpoint', 'health'] + interval: 30s + timeout: 20s + retries: 3 + networks: + - milvus + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + profiles: + - milvus + environment: + MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} + MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} + volumes: + - ./volumes/milvus/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + interval: 30s + timeout: 20s + retries: 3 + networks: + - milvus + + milvus-standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.3.1 + profiles: + - milvus + command: ['milvus', 'run', 'standalone'] + environment: + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} + common.security.authorizationEnabled: ${MILVUS_AUTHORIZATION_ENABLED:-true} + volumes: + - ./volumes/milvus/milvus:/var/lib/milvus + healthcheck: + test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + depends_on: + - etcd + - minio + ports: + - 19530:19530 + - 9091:9091 + networks: + - milvus + + # Opensearch vector database 
+  opensearch:
+    container_name: opensearch
+    image: opensearchproject/opensearch:latest
+    profiles:
+      - opensearch
+    environment:
+      discovery.type: ${OPENSEARCH_DISCOVERY_TYPE:-single-node}
+      bootstrap.memory_lock: ${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true}
+      OPENSEARCH_JAVA_OPTS: -Xms${OPENSEARCH_JAVA_OPTS_MIN:-512m} -Xmx${OPENSEARCH_JAVA_OPTS_MAX:-1024m}
+      OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123}
+    ulimits:
+      memlock:
+        soft: ${OPENSEARCH_MEMLOCK_SOFT:--1}
+        hard: ${OPENSEARCH_MEMLOCK_HARD:--1}
+      nofile:
+        soft: ${OPENSEARCH_NOFILE_SOFT:-65536}
+        hard: ${OPENSEARCH_NOFILE_HARD:-65536}
+    volumes:
+      - ./volumes/opensearch/data:/usr/share/opensearch/data
+    networks:
+      - opensearch-net
+
+  opensearch-dashboards:
+    container_name: opensearch-dashboards
+    image: opensearchproject/opensearch-dashboards:latest
+    profiles:
+      - opensearch
+    environment:
+      OPENSEARCH_HOSTS: '["https://opensearch:9200"]'
+    volumes:
+      - ./volumes/opensearch/opensearch_dashboards.yml:/usr/share/opensearch-dashboards/config/opensearch_dashboards.yml
+    networks:
+      - opensearch-net
+    depends_on:
+      - opensearch
+
+  # MyScale vector database
+  myscale:
+    container_name: myscale
+    image: myscale/myscaledb:1.6.4
+    profiles:
+      - myscale
+    restart: always
+    tty: true
+    volumes:
+      - ./volumes/myscale/data:/var/lib/clickhouse
+      - ./volumes/myscale/log:/var/log/clickhouse-server
+      - ./volumes/myscale/config/users.d/custom_users_config.xml:/etc/clickhouse-server/users.d/custom_users_config.xml
+    ports:
+      - ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}
+
+  # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html
+  # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites
+  elasticsearch:
+    image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3
+    container_name: elasticsearch
+    profiles:
+      - elasticsearch
+    restart: always
+    volumes:
+      - dify_es01_data:/usr/share/elasticsearch/data
+    environment:
+      ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
+      cluster.name: dify-es-cluster
+      node.name: dify-es0
+      discovery.type: single-node
+      xpack.license.self_generated.type: trial
+      xpack.security.enabled: 'true'
+      xpack.security.enrollment.enabled: 'false'
+      xpack.security.http.ssl.enabled: 'false'
+    ports:
+      - ${ELASTICSEARCH_PORT:-9200}:9200
+    healthcheck:
+      test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
+      interval: 30s
+      timeout: 10s
+      retries: 50
+
+  # https://www.elastic.co/guide/en/kibana/current/docker.html
+  # https://www.elastic.co/guide/en/kibana/current/settings.html
+  kibana:
+    image: docker.elastic.co/kibana/kibana:8.14.3
+    container_name: kibana
+    profiles:
+      - elasticsearch
+    depends_on:
+      - elasticsearch
+    restart: always
+    environment:
+      XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY: d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa
+      NO_PROXY: localhost,127.0.0.1,elasticsearch,kibana
+      XPACK_SECURITY_ENABLED: 'true'
+      XPACK_SECURITY_ENROLLMENT_ENABLED: 'false'
+      XPACK_SECURITY_HTTP_SSL_ENABLED: 'false'
+      XPACK_FLEET_ISAIRGAPPED: 'true'
+      I18N_LOCALE: zh-CN
+      SERVER_PORT: '5601'
+      ELASTICSEARCH_HOSTS: http://elasticsearch:9200
+    ports:
+      - ${KIBANA_PORT:-5601}:5601
+    healthcheck:
+      test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
+      interval: 30s
+      timeout: 10s
+      retries: 3
+
+  # unstructured
+  # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
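The regenerated docker/docker-compose.yaml in the next hunk opens with a warning that it is produced by a generate_docker_compose step from docker-compose-template.yaml plus docker/.env.example. A hedged sketch of that generation step (render_shared_env and the exact formatting are illustrative assumptions; only the `KEY: ${KEY:-default}` output shape is taken from the generated file):

from pathlib import Path

from dotenv import dotenv_values

def render_shared_env(env_example: Path) -> str:
    # Turn each KEY=default pair from .env.example into `KEY: ${KEY:-default}`
    # under the x-shared-env anchor, so values set in .env override the
    # example defaults at compose time.
    lines = ["x-shared-env: &shared-api-worker-env"]
    for key, default in dotenv_values(env_example).items():
        lines.append(f"  {key}: ${{{key}:-{default or ''}}}")
    return "\n".join(lines)

print(render_shared_env(Path("docker") / ".env.example"))

The unstructured service definition follows.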
+ unstructured: + image: downloads.unstructured.io/unstructured-io/unstructured-api:latest + profiles: + - unstructured + restart: always + volumes: + - ./volumes/unstructured:/app/data + +networks: + # create a network between sandbox, api and ssrf_proxy, and can not access outside. + ssrf_proxy_network: + driver: bridge + internal: true + milvus: + driver: bridge + opensearch-net: + driver: bridge + internal: true + +volumes: + oradata: + dify_es01_data: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index cfc3d750c9..7122f4a6d0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,28 +1,34 @@ +# ================================================================== +# WARNING: This file is auto-generated by generate_docker_compose +# Do not modify this file directly. Instead, update the .env.example +# or docker-compose-template.yaml and regenerate this file. +# ================================================================== + x-shared-env: &shared-api-worker-env - WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + CONSOLE_API_URL: ${CONSOLE_API_URL:-} + CONSOLE_WEB_URL: ${CONSOLE_WEB_URL:-} + SERVICE_API_URL: ${SERVICE_API_URL:-} + APP_API_URL: ${APP_API_URL:-} + APP_WEB_URL: ${APP_WEB_URL:-} + FILES_URL: ${FILES_URL:-} LOG_LEVEL: ${LOG_LEVEL:-INFO} - LOG_FILE: ${LOG_FILE:-} + LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - # Log dateformat - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} - # Log Timezone + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} - CONSOLE_WEB_URL: ${CONSOLE_WEB_URL:-} - CONSOLE_API_URL: ${CONSOLE_API_URL:-} - SERVICE_API_URL: ${SERVICE_API_URL:-} - APP_WEB_URL: ${APP_WEB_URL:-} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} - OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} - FILES_URL: ${FILES_URL:-} - FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} - APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} - MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} + OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} + FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} + ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} + APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} + APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_PORT: ${DIFY_PORT:-5001} SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} @@ -43,6 +49,11 @@ x-shared-env: &shared-api-worker-env SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} + POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100} + POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB} + POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} + POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} + POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} REDIS_USERNAME: ${REDIS_USERNAME:-} @@ -55,75 +66,73 @@ x-shared-env: 
&shared-api-worker-env REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} - REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} + REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} CELERY_SENTINEL_SOCKET_TIMEOUT: ${CELERY_SENTINEL_SOCKET_TIMEOUT:-0.1} WEB_API_CORS_ALLOW_ORIGINS: ${WEB_API_CORS_ALLOW_ORIGINS:-*} CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} - STORAGE_TYPE: ${STORAGE_TYPE:-local} - STORAGE_LOCAL_PATH: ${STORAGE_LOCAL_PATH:-storage} - S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + STORAGE_TYPE: ${STORAGE_TYPE:-opendal} + OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} + OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} S3_ENDPOINT: ${S3_ENDPOINT:-} - S3_BUCKET_NAME: ${S3_BUCKET_NAME:-} + S3_REGION: ${S3_REGION:-us-east-1} + S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} - S3_REGION: ${S3_REGION:-us-east-1} - AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-} - AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-} - AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-} - GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-} - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-} - ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-} - ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-} - ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-} - ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-} + S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} + AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} + AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} + AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} + ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} + ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} + ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} - ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-} - TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-} - TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-} - TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-} - TENCENT_COS_REGION: ${TENCENT_COS_REGION:-} - TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-} - HUAWEI_OBS_BUCKET_NAME: ${HUAWEI_OBS_BUCKET_NAME:-} - HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-} - HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-} - 
HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-} - OCI_ENDPOINT: ${OCI_ENDPOINT:-} - OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-} - OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-} - OCI_SECRET_KEY: ${OCI_SECRET_KEY:-} - OCI_REGION: ${OCI_REGION:-} - VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-} - VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-} - VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-} - VOLCENGINE_TOS_ENDPOINT: ${VOLCENGINE_TOS_ENDPOINT:-} - VOLCENGINE_TOS_REGION: ${VOLCENGINE_TOS_REGION:-} - BAIDU_OBS_BUCKET_NAME: ${BAIDU_OBS_BUCKET_NAME:-} - BAIDU_OBS_SECRET_KEY: ${BAIDU_OBS_SECRET_KEY:-} - BAIDU_OBS_ACCESS_KEY: ${BAIDU_OBS_ACCESS_KEY:-} - BAIDU_OBS_ENDPOINT: ${BAIDU_OBS_ENDPOINT:-} + ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} + TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} + TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} + TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} + TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} + TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} + OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} + OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} + OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} + OCI_REGION: ${OCI_REGION:-us-ashburn-1} + HUAWEI_OBS_BUCKET_NAME: ${HUAWEI_OBS_BUCKET_NAME:-your-bucket-name} + HUAWEI_OBS_SECRET_KEY: ${HUAWEI_OBS_SECRET_KEY:-your-secret-key} + HUAWEI_OBS_ACCESS_KEY: ${HUAWEI_OBS_ACCESS_KEY:-your-access-key} + HUAWEI_OBS_SERVER: ${HUAWEI_OBS_SERVER:-your-server-url} + VOLCENGINE_TOS_BUCKET_NAME: ${VOLCENGINE_TOS_BUCKET_NAME:-your-bucket-name} + VOLCENGINE_TOS_SECRET_KEY: ${VOLCENGINE_TOS_SECRET_KEY:-your-secret-key} + VOLCENGINE_TOS_ACCESS_KEY: ${VOLCENGINE_TOS_ACCESS_KEY:-your-access-key} + VOLCENGINE_TOS_ENDPOINT: ${VOLCENGINE_TOS_ENDPOINT:-your-server-url} + VOLCENGINE_TOS_REGION: ${VOLCENGINE_TOS_REGION:-your-region} + BAIDU_OBS_BUCKET_NAME: ${BAIDU_OBS_BUCKET_NAME:-your-bucket-name} + BAIDU_OBS_SECRET_KEY: ${BAIDU_OBS_SECRET_KEY:-your-secret-key} + BAIDU_OBS_ACCESS_KEY: ${BAIDU_OBS_ACCESS_KEY:-your-access-key} + BAIDU_OBS_ENDPOINT: ${BAIDU_OBS_ENDPOINT:-your-server-url} + SUPABASE_BUCKET_NAME: ${SUPABASE_BUCKET_NAME:-your-bucket-name} + SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} + SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} + QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-'couchbase-server'} - COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} - COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} - COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} - COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default} - MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} + MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} @@ -133,172 +142,264 @@ x-shared-env: &shared-api-worker-env 
   MYSCALE_PASSWORD: ${MYSCALE_PASSWORD:-}
   MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify}
   MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-}
-  RELYT_HOST: ${RELYT_HOST:-db}
-  RELYT_PORT: ${RELYT_PORT:-5432}
-  RELYT_USER: ${RELYT_USER:-postgres}
-  RELYT_PASSWORD: ${RELYT_PASSWORD:-difyai123456}
-  RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
+  COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"}
+  COUCHBASE_USER: ${COUCHBASE_USER:-Administrator}
+  COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password}
+  COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings}
+  COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default}
   PGVECTOR_HOST: ${PGVECTOR_HOST:-pgvector}
   PGVECTOR_PORT: ${PGVECTOR_PORT:-5432}
   PGVECTOR_USER: ${PGVECTOR_USER:-postgres}
   PGVECTOR_PASSWORD: ${PGVECTOR_PASSWORD:-difyai123456}
   PGVECTOR_DATABASE: ${PGVECTOR_DATABASE:-dify}
+  PGVECTOR_MIN_CONNECTION: ${PGVECTOR_MIN_CONNECTION:-1}
+  PGVECTOR_MAX_CONNECTION: ${PGVECTOR_MAX_CONNECTION:-5}
+  PGVECTO_RS_HOST: ${PGVECTO_RS_HOST:-pgvecto-rs}
+  PGVECTO_RS_PORT: ${PGVECTO_RS_PORT:-5432}
+  PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres}
+  PGVECTO_RS_PASSWORD: ${PGVECTO_RS_PASSWORD:-difyai123456}
+  PGVECTO_RS_DATABASE: ${PGVECTO_RS_DATABASE:-dify}
+  ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-your-ak}
+  ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-your-sk}
+  ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-cn-hangzhou}
+  ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-gp-ab123456}
+  ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-testaccount}
+  ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-testpassword}
+  ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
+  ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-difypassword}
+  ANALYTICDB_HOST: ${ANALYTICDB_HOST:-gp-test.aliyuncs.com}
+  ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
+  ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
+  ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
   TIDB_VECTOR_HOST: ${TIDB_VECTOR_HOST:-tidb}
   TIDB_VECTOR_PORT: ${TIDB_VECTOR_PORT:-4000}
   TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-}
   TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-}
   TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify}
-  TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1}
+  TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"}
   TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify}
   TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20}
   TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false}
   TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334}
   TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify}
   TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify}
-  TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1}
-  TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1}
+  TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"}
+  TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"}
   TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1}
   TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify}
   TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100}
-  ORACLE_HOST: ${ORACLE_HOST:-oracle}
-  ORACLE_PORT: ${ORACLE_PORT:-1521}
-  ORACLE_USER: ${ORACLE_USER:-dify}
-  ORACLE_PASSWORD: ${ORACLE_PASSWORD:-dify}
-  ORACLE_DATABASE: ${ORACLE_DATABASE:-FREEPDB1}
   CHROMA_HOST: ${CHROMA_HOST:-127.0.0.1}
   CHROMA_PORT: ${CHROMA_PORT:-8000}
   CHROMA_TENANT: ${CHROMA_TENANT:-default_tenant}
   CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database}
   CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider}
   CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-}
-  ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0}
-  ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
-  ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
-  ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
-  LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070}
-  LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
-  LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm }
-  KIBANA_PORT: ${KIBANA_PORT:-5601}
-  # AnalyticDB configuration
-  ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-}
-  ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-}
-  ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-}
-  ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-}
-  ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-}
-  ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
-  ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
-  ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
-  ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
-  ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
-  ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
-  ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
+  ORACLE_HOST: ${ORACLE_HOST:-oracle}
+  ORACLE_PORT: ${ORACLE_PORT:-1521}
+  ORACLE_USER: ${ORACLE_USER:-dify}
+  ORACLE_PASSWORD: ${ORACLE_PASSWORD:-dify}
+  ORACLE_DATABASE: ${ORACLE_DATABASE:-FREEPDB1}
+  RELYT_HOST: ${RELYT_HOST:-db}
+  RELYT_PORT: ${RELYT_PORT:-5432}
+  RELYT_USER: ${RELYT_USER:-postgres}
+  RELYT_PASSWORD: ${RELYT_PASSWORD:-difyai123456}
+  RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
   OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
   OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
   OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
   OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
   OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
-  TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
+  TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"}
   TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
   TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}
   TENCENT_VECTOR_DB_USERNAME: ${TENCENT_VECTOR_DB_USERNAME:-dify}
   TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
   TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
   TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
-  BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
+  ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0}
+  ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200}
+  ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic}
+  ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
+  KIBANA_PORT: ${KIBANA_PORT:-5601}
+  BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"}
   BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
   BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
   BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
   BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
   BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
   BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
-  VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-dify}
-  VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-dify}
+  VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak}
+  VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk}
   VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai}
   VIKINGDB_HOST: ${VIKINGDB_HOST:-api-vikingdb.xxx.volces.com}
   VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http}
-  UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io}
-  UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
-  UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
-  UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
-  ETL_TYPE: ${ETL_TYPE:-dify}
-  UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
-  UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
-  PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
-  CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
-  MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64}
-  MULTIMODAL_SEND_VIDEO_FORMAT: ${MULTIMODAL_SEND_VIDEO_FORMAT:-base64}
-  UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
-  UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
-  UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}
-  SENTRY_DSN: ${API_SENTRY_DSN:-}
-  SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
-  SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
-  NOTION_INTEGRATION_TYPE: ${NOTION_INTEGRATION_TYPE:-public}
-  NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-}
-  NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-}
-  NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-}
-  MAIL_TYPE: ${MAIL_TYPE:-resend}
-  MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-}
-  SMTP_SERVER: ${SMTP_SERVER:-}
-  SMTP_PORT: ${SMTP_PORT:-465}
-  SMTP_USERNAME: ${SMTP_USERNAME:-}
-  SMTP_PASSWORD: ${SMTP_PASSWORD:-}
-  SMTP_USE_TLS: ${SMTP_USE_TLS:-true}
-  SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false}
-  RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key}
-  RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com}
-  INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
-  INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
-  RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
-  CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
-  CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
-  CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10}
-  CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60}
-  CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10}
-  CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
-  CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808}
-  CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5}
-  CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20}
-  CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000}
-  TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000}
-  CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30}
-  CODE_MAX_OBJECT_ARRAY_LENGTH: ${CODE_MAX_OBJECT_ARRAY_LENGTH:-30}
-  CODE_MAX_NUMBER_ARRAY_LENGTH: ${CODE_MAX_NUMBER_ARRAY_LENGTH:-1000}
-  WORKFLOW_MAX_EXECUTION_STEPS: ${WORKFLOW_MAX_EXECUTION_STEPS:-500}
-  WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200}
-  WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5}
-  SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128}
-  SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128}
-  HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
-  HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
-  APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-12000}
-  POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-}
-  POSITION_TOOL_INCLUDES: ${POSITION_TOOL_INCLUDES:-}
-  POSITION_TOOL_EXCLUDES: ${POSITION_TOOL_EXCLUDES:-}
-  POSITION_PROVIDER_PINS: ${POSITION_PROVIDER_PINS:-}
-  POSITION_PROVIDER_INCLUDES: ${POSITION_PROVIDER_INCLUDES:-}
-  POSITION_PROVIDER_EXCLUDES: ${POSITION_PROVIDER_EXCLUDES:-}
-  MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
-  OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-http://oceanbase-vector}
+  VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30}
+  VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30}
+  LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"}
+  LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
+  LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm}
+  OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase}
   OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881}
   OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test}
   OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456}
   OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test}
   OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
   OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
+  UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"}
+  UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
+  UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
+  UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
+  ETL_TYPE: ${ETL_TYPE:-dify}
+  UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
+  UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
+  SCARF_NO_ANALYTICS: ${SCARF_NO_ANALYTICS:-true}
+  PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512}
+  CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024}
+  MULTIMODAL_SEND_FORMAT: ${MULTIMODAL_SEND_FORMAT:-base64}
+  UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
+  UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
+  UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}
+  SENTRY_DSN: ${SENTRY_DSN:-}
+  API_SENTRY_DSN: ${API_SENTRY_DSN:-}
+  API_SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
+  API_SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
+  WEB_SENTRY_DSN: ${WEB_SENTRY_DSN:-}
+  NOTION_INTEGRATION_TYPE: ${NOTION_INTEGRATION_TYPE:-public}
+  NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-}
+  NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-}
+  NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-}
+  MAIL_TYPE: ${MAIL_TYPE:-resend}
+  MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-}
+  RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"}
+  RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key}
+  SMTP_SERVER: ${SMTP_SERVER:-}
+  SMTP_PORT: ${SMTP_PORT:-465}
+  SMTP_USERNAME: ${SMTP_USERNAME:-}
+  SMTP_PASSWORD: ${SMTP_PASSWORD:-}
+  SMTP_USE_TLS: ${SMTP_USE_TLS:-true}
+  SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false}
+  INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
+  INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
+  RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
+  CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"}
+  CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox}
+  CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
+  CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808}
+  CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5}
+  CODE_MAX_PRECISION: ${CODE_MAX_PRECISION:-20}
+  CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000}
+  CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30}
+  CODE_MAX_OBJECT_ARRAY_LENGTH: ${CODE_MAX_OBJECT_ARRAY_LENGTH:-30}
+  CODE_MAX_NUMBER_ARRAY_LENGTH: ${CODE_MAX_NUMBER_ARRAY_LENGTH:-1000}
+  CODE_EXECUTION_CONNECT_TIMEOUT: ${CODE_EXECUTION_CONNECT_TIMEOUT:-10}
+  CODE_EXECUTION_READ_TIMEOUT: ${CODE_EXECUTION_READ_TIMEOUT:-60}
+  CODE_EXECUTION_WRITE_TIMEOUT: ${CODE_EXECUTION_WRITE_TIMEOUT:-10}
+  TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000}
+  WORKFLOW_MAX_EXECUTION_STEPS: ${WORKFLOW_MAX_EXECUTION_STEPS:-500}
+  WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200}
+  WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5}
+  MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
+  WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
+  WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
+  HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
+  HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
+  SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"}
+  SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"}
+  TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
+  PGUSER: ${PGUSER:-${DB_USERNAME}}
+  POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
+  POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
+  PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
+  SANDBOX_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox}
+  SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release}
+  SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15}
+  SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true}
+  SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"}
+  SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"}
+  SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+  WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate}
+  WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25}
+  WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: ${WEAVIATE_AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED:-true}
+  WEAVIATE_DEFAULT_VECTORIZER_MODULE: ${WEAVIATE_DEFAULT_VECTORIZER_MODULE:-none}
+  WEAVIATE_CLUSTER_HOSTNAME: ${WEAVIATE_CLUSTER_HOSTNAME:-node1}
+  WEAVIATE_AUTHENTICATION_APIKEY_ENABLED: ${WEAVIATE_AUTHENTICATION_APIKEY_ENABLED:-true}
+  WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS: ${WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
+  WEAVIATE_AUTHENTICATION_APIKEY_USERS: ${WEAVIATE_AUTHENTICATION_APIKEY_USERS:-hello@dify.ai}
+  WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED: ${WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED:-true}
+  WEAVIATE_AUTHORIZATION_ADMINLIST_USERS: ${WEAVIATE_AUTHORIZATION_ADMINLIST_USERS:-hello@dify.ai}
+  CHROMA_SERVER_AUTHN_CREDENTIALS: ${CHROMA_SERVER_AUTHN_CREDENTIALS:-difyai123456}
+  CHROMA_SERVER_AUTHN_PROVIDER: ${CHROMA_SERVER_AUTHN_PROVIDER:-chromadb.auth.token_authn.TokenAuthenticationServerProvider}
+  CHROMA_IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE}
+  ORACLE_PWD: ${ORACLE_PWD:-Dify123456}
+  ORACLE_CHARACTERSET: ${ORACLE_CHARACTERSET:-AL32UTF8}
+  ETCD_AUTO_COMPACTION_MODE: ${ETCD_AUTO_COMPACTION_MODE:-revision}
+  ETCD_AUTO_COMPACTION_RETENTION: ${ETCD_AUTO_COMPACTION_RETENTION:-1000}
+  ETCD_QUOTA_BACKEND_BYTES: ${ETCD_QUOTA_BACKEND_BYTES:-4294967296}
+  ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000}
+  MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin}
+  MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin}
+  ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"}
+  MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"}
+  MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true}
+  PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres}
+  PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456}
+  PGVECTOR_POSTGRES_DB: ${PGVECTOR_POSTGRES_DB:-dify}
+  PGVECTOR_PGDATA: ${PGVECTOR_PGDATA:-/var/lib/postgresql/data/pgdata}
+  OPENSEARCH_DISCOVERY_TYPE: ${OPENSEARCH_DISCOVERY_TYPE:-single-node}
+  OPENSEARCH_BOOTSTRAP_MEMORY_LOCK: ${OPENSEARCH_BOOTSTRAP_MEMORY_LOCK:-true}
+  OPENSEARCH_JAVA_OPTS_MIN: ${OPENSEARCH_JAVA_OPTS_MIN:-512m}
+  OPENSEARCH_JAVA_OPTS_MAX: ${OPENSEARCH_JAVA_OPTS_MAX:-1024m}
+  OPENSEARCH_INITIAL_ADMIN_PASSWORD: ${OPENSEARCH_INITIAL_ADMIN_PASSWORD:-Qazwsxedc!@#123}
+  OPENSEARCH_MEMLOCK_SOFT: ${OPENSEARCH_MEMLOCK_SOFT:--1}
+  OPENSEARCH_MEMLOCK_HARD: ${OPENSEARCH_MEMLOCK_HARD:--1}
+  OPENSEARCH_NOFILE_SOFT: ${OPENSEARCH_NOFILE_SOFT:-65536}
+  OPENSEARCH_NOFILE_HARD: ${OPENSEARCH_NOFILE_HARD:-65536}
+  NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
+  NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
+  NGINX_PORT: ${NGINX_PORT:-80}
+  NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443}
+  NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
+  NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
+  NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"}
+  NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
+  NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
+  NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
+  NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
+  NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
+  NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
+  CERTBOT_EMAIL: ${CERTBOT_EMAIL:-your_email@example.com}
+  CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-your_domain.com}
+  CERTBOT_OPTIONS: ${CERTBOT_OPTIONS:-}
+  SSRF_HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
+  SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid}
+  SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194}
+  SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox}
+  COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"}
+  EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80}
+  EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443}
+  POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-}
+  POSITION_TOOL_INCLUDES: ${POSITION_TOOL_INCLUDES:-}
+  POSITION_TOOL_EXCLUDES: ${POSITION_TOOL_EXCLUDES:-}
+  POSITION_PROVIDER_PINS: ${POSITION_PROVIDER_PINS:-}
+  POSITION_PROVIDER_INCLUDES: ${POSITION_PROVIDER_INCLUDES:-}
+  POSITION_PROVIDER_EXCLUDES: ${POSITION_PROVIDER_EXCLUDES:-}
+  CSP_WHITELIST: ${CSP_WHITELIST:-}
   CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
-  RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0}
+  MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100}
 
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.13.2
+    image: langgenius/dify-api:0.14.2
     restart: always
     environment:
       # Use the shared environment variables.
       <<: *shared-api-worker-env
       # Startup mode, 'api' starts the API server.
       MODE: api
+      SENTRY_DSN: ${API_SENTRY_DSN:-}
+      SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
+      SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
     depends_on:
       - db
       - redis
@@ -312,13 +413,16 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.13.2
+    image: langgenius/dify-api:0.14.2
     restart: always
     environment:
       # Use the shared environment variables.
       <<: *shared-api-worker-env
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
       MODE: worker
+      SENTRY_DSN: ${API_SENTRY_DSN:-}
+      SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
+      SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
     depends_on:
       - db
       - redis
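The shared `x-shared-env` block above leans entirely on Compose-style `${VAR:-default}` interpolation: each key resolves to the host environment value when it is set and non-empty, and otherwise falls back to the inline default, including nested defaults such as `${PGUSER:-${DB_USERNAME}}`. A minimal Python sketch of that resolution rule (this is not Compose's actual parser, and it only handles one level of nesting):

```python
import re

# Simplified Compose-style interpolation: "${VAR}" and "${VAR:-default}".
# ":-" falls back when the variable is unset OR set to the empty string.
PATTERN = re.compile(r"\$\{(\w+)(?::-([^{}]*))?\}")


def interpolate(text: str, env: dict[str, str]) -> str:
    def repl(match: re.Match) -> str:
        name = match.group(1)
        default = match.group(2) or ""
        value = env.get(name, "")
        return value if value else default

    # Two passes so a default that is itself "${...}" resolves as well,
    # e.g. ${PGUSER:-${DB_USERNAME}} from the shared block above.
    for _ in range(2):
        text = PATTERN.sub(repl, text)
    return text


print(interpolate("${SANDBOX_PORT:-8194}", {}))  # -> 8194
print(interpolate("${PGUSER:-${DB_USERNAME}}", {"DB_USERNAME": "postgres"}))  # -> postgres
```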
@@ -331,7 +435,7 @@ services:
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.13.2
+    image: langgenius/dify-web:0.14.2
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -495,8 +599,8 @@ services:
   # For production use, please refer to https://github.com/pingcap/tidb-docker-compose
   tidb:
     image: pingcap/tidb:v8.4.0
-    ports:
-      - "4000:4000"
+    profiles:
+      - tidb
     command:
       - --store=unistore
     restart: always
diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose
new file mode 100755
index 0000000000..54b6d55217
--- /dev/null
+++ b/docker/generate_docker_compose
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+import os
+import re
+import sys
+
+
+def parse_env_example(file_path):
+    """
+    Parses the .env.example file and returns a dictionary with variable names as keys and default values as values.
+    """
+    env_vars = {}
+    with open(file_path, "r") as f:
+        for line_number, line in enumerate(f, 1):
+            line = line.strip()
+            # Ignore empty lines and comments
+            if not line or line.startswith("#"):
+                continue
+            # Use regex to parse KEY=VALUE
+            match = re.match(r"^([^=]+)=(.*)$", line)
+            if match:
+                key = match.group(1).strip()
+                value = match.group(2).strip()
+                # Remove possible quotes around the value
+                if (value.startswith('"') and value.endswith('"')) or (
+                    value.startswith("'") and value.endswith("'")
+                ):
+                    value = value[1:-1]
+                env_vars[key] = value
+            else:
+                print(f"Warning: Unable to parse line {line_number}: {line}")
+    return env_vars
+
+
+def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"):
+    """
+    Generates a shared environment variables block as a YAML string.
+    """
+    lines = [f"x-shared-env: &{anchor_name}"]
+    for key, default in env_vars.items():
+        # If default value is empty, use ${KEY:-}
+        if default == "":
+            lines.append(f"  {key}: ${{{key}:-}}")
+        else:
+            # If default value contains special characters, wrap it in quotes
+            if re.search(r"[:\s]", default):
+                default = f'"{default}"'
+            lines.append(f"  {key}: ${{{key}:-{default}}}")
+    return "\n".join(lines)
+
+
+def insert_shared_env(template_path, output_path, shared_env_block, header_comments):
+    """
+    Inserts the shared environment variables block and header comments into the template file,
+    removing any existing x-shared-env anchors, and generates the final docker-compose.yaml file.
+    """
+    with open(template_path, "r") as f:
+        template_content = f.read()
+
+    # Remove existing x-shared-env: &shared-api-worker-env lines
+    template_content = re.sub(
+        r"^x-shared-env: &shared-api-worker-env\s*\n?",
+        "",
+        template_content,
+        flags=re.MULTILINE,
+    )
+
+    # Prepare the final content with header comments and shared env block
+    final_content = f"{header_comments}\n{shared_env_block}\n\n{template_content}"
+
+    with open(output_path, "w") as f:
+        f.write(final_content)
+    print(f"Generated {output_path}")
+
+
+def main():
+    env_example_path = ".env.example"
+    template_path = "docker-compose-template.yaml"
+    output_path = "docker-compose.yaml"
+    anchor_name = "shared-api-worker-env"  # Can be modified as needed
+
+    # Define header comments to be added at the top of docker-compose.yaml
+    header_comments = (
+        "# ==================================================================\n"
+        "# WARNING: This file is auto-generated by generate_docker_compose\n"
+        "# Do not modify this file directly. Instead, update the .env.example\n"
+        "# or docker-compose-template.yaml and regenerate this file.\n"
+        "# ==================================================================\n"
+    )
+
+    # Check if required files exist
+    for path in [env_example_path, template_path]:
+        if not os.path.isfile(path):
+            print(f"Error: File {path} does not exist.")
+            sys.exit(1)
+
+    # Parse .env.example file
+    env_vars = parse_env_example(env_example_path)
+
+    if not env_vars:
+        print("Warning: No environment variables found in .env.example.")
+
+    # Generate shared environment variables block
+    shared_env_block = generate_shared_env_block(env_vars, anchor_name)
+
+    # Insert shared environment variables block and header comments into the template
+    insert_shared_env(template_path, output_path, shared_env_block, header_comments)
+
+
+if __name__ == "__main__":
+    main()
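A quick way to see what this script emits: the inner loop of `generate_shared_env_block` quotes any default containing `:` or whitespace, which is exactly why URL defaults such as `http://sandbox:8194` show up quoted in the regenerated `x-shared-env` block above. A small standalone sketch of that loop over a few sample entries (the sample values are illustrative, not the full `.env.example`):

```python
import re

samples = {
    "SANDBOX_PORT": "8194",                            # plain value: unquoted
    "CODE_EXECUTION_ENDPOINT": "http://sandbox:8194",  # contains ':': quoted
    "UNSTRUCTURED_API_URL": "",                        # empty: bare ${KEY:-}
}

for key, default in samples.items():
    if default == "":
        print(f"  {key}: ${{{key}:-}}")
    else:
        # Same rule as generate_shared_env_block: quote specials.
        if re.search(r"[:\s]", default):
            default = f'"{default}"'
        print(f"  {key}: ${{{key}:-{default}}}")
```

Running it prints the three shapes seen throughout the shared block:

```
  SANDBOX_PORT: ${SANDBOX_PORT:-8194}
  CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"}
  UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
```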
diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py
index e664488301..ee1b5c57e1 100644
--- a/sdks/python-client/dify_client/client.py
+++ b/sdks/python-client/dify_client/client.py
@@ -160,7 +160,10 @@ class WorkflowClient(DifyClient):
 class KnowledgeBaseClient(DifyClient):
     def __init__(
-        self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None
+        self,
+        api_key,
+        base_url: str = "https://api.dify.ai/v1",
+        dataset_id: str | None = None,
     ):
         """
         Construct a KnowledgeBaseClient object.
@@ -187,7 +190,9 @@ class KnowledgeBaseClient(DifyClient):
             "GET", f"/datasets?page={page}&limit={page_size}", **kwargs
         )
 
-    def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs):
+    def create_document_by_text(
+        self, name, text, extra_params: dict | None = None, **kwargs
+    ):
         """
         Create a document by text.
@@ -225,7 +230,7 @@ class KnowledgeBaseClient(DifyClient):
         return self._send_request("POST", url, json=data, **kwargs)
 
     def update_document_by_text(
-        self, document_id, name, text, extra_params: dict = None, **kwargs
+        self, document_id, name, text, extra_params: dict | None = None, **kwargs
     ):
         """
         Update a document by text.
@@ -262,7 +267,7 @@ class KnowledgeBaseClient(DifyClient):
         return self._send_request("POST", url, json=data, **kwargs)
 
     def create_document_by_file(
-        self, file_path, original_document_id=None, extra_params: dict = None
+        self, file_path, original_document_id=None, extra_params: dict | None = None
     ):
         """
         Create a document by file.
@@ -304,7 +309,7 @@ class KnowledgeBaseClient(DifyClient):
         )
 
     def update_document_by_file(
-        self, document_id, file_path, extra_params: dict = None
+        self, document_id, file_path, extra_params: dict | None = None
     ):
         """
         Update a document by file.
@@ -372,7 +377,11 @@ class KnowledgeBaseClient(DifyClient):
         return self._send_request("DELETE", url)
 
     def list_documents(
-        self, page: int = None, page_size: int = None, keyword: str = None, **kwargs
+        self,
+        page: int | None = None,
+        page_size: int | None = None,
+        keyword: str | None = None,
+        **kwargs,
     ):
         """
         Get a list of documents in this dataset.
@@ -402,7 +411,11 @@ class KnowledgeBaseClient(DifyClient):
         return self._send_request("POST", url, json=data, **kwargs)
 
     def query_segments(
-        self, document_id, keyword: str = None, status: str = None, **kwargs
+        self,
+        document_id,
+        keyword: str | None = None,
+        status: str | None = None,
+        **kwargs,
     ):
         """
         Query segments in this document.
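For SDK callers, the signature change is behavior-preserving: the optional parameters were always allowed to be `None`, and the annotations now say so explicitly (`dict | None`, `int | None`, `str | None`, which, absent a `from __future__ import annotations` shim, also means the client now assumes Python 3.10+ union syntax). A minimal usage sketch against the updated signatures; the key and dataset id below are placeholders:

```python
from dify_client.client import KnowledgeBaseClient

# Placeholders; substitute a real API key and dataset id.
client = KnowledgeBaseClient(
    api_key="your-api-key",
    dataset_id="your-dataset-id",
)

# extra_params can be omitted entirely or passed explicitly as a dict.
client.create_document_by_text(name="notes", text="hello world")

# page / page_size / keyword are all optional, now by annotation as well.
client.list_documents(page=1, page_size=20)
```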
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 7a5347c7d5..1d96320309 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -25,6 +25,7 @@ import { fetchAppDetail, fetchAppSSO } from '@/service/apps' import AppContext, { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import type { App } from '@/types/app' export type IAppDetailLayoutProps = { children: React.ReactNode @@ -41,12 +42,14 @@ const AppDetailLayout: FC = (props) => { const pathname = usePathname() const media = useBreakpoints() const isMobile = media === MediaType.mobile - const { isCurrentWorkspaceEditor } = useAppContext() + const { isCurrentWorkspaceEditor, isLoadingCurrentWorkspace } = useAppContext() const { appDetail, setAppDetail, setAppSiderbarExpand } = useStore(useShallow(state => ({ appDetail: state.appDetail, setAppDetail: state.setAppDetail, setAppSiderbarExpand: state.setAppSiderbarExpand, }))) + const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) + const [appDetailRes, setAppDetailRes] = useState(null) const [navigation, setNavigation] = useState = (props) => { useEffect(() => { setAppDetail() + setIsLoadingAppDetail(true) fetchAppDetail({ url: '/apps', id: appId }).then((res) => { - // redirection - const canIEditApp = isCurrentWorkspaceEditor - if (!canIEditApp && (pathname.endsWith('configuration') || pathname.endsWith('workflow') || pathname.endsWith('logs'))) { - router.replace(`/app/${appId}/overview`) - return - } - if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { - router.replace(`/app/${appId}/workflow`) - } - else if ((res.mode !== 'workflow' && res.mode !== 'advanced-chat') && (pathname).endsWith('workflow')) { - router.replace(`/app/${appId}/configuration`) - } - else { - setAppDetail({ ...res, enable_sso: false }) - setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) - if (systemFeatures.enable_web_sso_switch_component && canIEditApp) { - fetchAppSSO({ appId }).then((ssoRes) => { - setAppDetail({ ...res, enable_sso: ssoRes.enabled }) - }) - } - } + setAppDetailRes(res) }).catch((e: any) => { if (e.status === 404) router.replace('/apps') + }).finally(() => { + setIsLoadingAppDetail(false) }) - }, [appId, isCurrentWorkspaceEditor, systemFeatures, getNavigations, pathname, router, setAppDetail]) + }, [appId, router, setAppDetail]) + + useEffect(() => { + if (!appDetailRes || isLoadingCurrentWorkspace || isLoadingAppDetail) + return + const res = appDetailRes + // redirection + const canIEditApp = isCurrentWorkspaceEditor + if (!canIEditApp && (pathname.endsWith('configuration') || pathname.endsWith('workflow') || pathname.endsWith('logs'))) { + router.replace(`/app/${appId}/overview`) + return + } + if ((res.mode === 'workflow' || res.mode === 'advanced-chat') && (pathname).endsWith('configuration')) { + router.replace(`/app/${appId}/workflow`) + } + else if ((res.mode !== 'workflow' && res.mode !== 'advanced-chat') && (pathname).endsWith('workflow')) { + router.replace(`/app/${appId}/configuration`) + } + else { + setAppDetail({ ...res, enable_sso: false }) + setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + if (systemFeatures.enable_web_sso_switch_component && canIEditApp) { 
+ fetchAppSSO({ appId }).then((ssoRes) => { + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + }) + } + } + }, [appDetailRes, appId, getNavigations, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace, pathname, router, setAppDetail, systemFeatures.enable_web_sso_switch_component]) useUnmount(() => { setAppDetail() diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx index d353cf2394..a90af4ea85 100644 --- a/web/app/(commonLayout)/apps/NewAppCard.tsx +++ b/web/app/(commonLayout)/apps/NewAppCard.tsx @@ -18,7 +18,6 @@ export type CreateAppCardProps = { onSuccess?: () => void } -// eslint-disable-next-line react/display-name const CreateAppCard = forwardRef(({ className, onSuccess }, ref) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -44,24 +43,22 @@ const CreateAppCard = forwardRef(({ classNam >
{t('app.createApp')}
-
setShowNewAppModal(true)}> +
-
setShowNewAppTemplateDialog(true)}> + +
-
-
setShowCreateFromDSLModal(true)} - > -
+ +
+
+ setShowNewAppModal(false)} @@ -108,4 +105,6 @@ const CreateAppCard = forwardRef(({ classNam ) }) +CreateAppCard.displayName = 'CreateAppCard' export default CreateAppCard +export { CreateAppCard } diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index 71540ce3b1..c7af05793f 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -18,10 +18,10 @@ import { IS_CE_EDITION } from '@/config' import Input from '@/app/components/base/input' const titleClassName = ` - text-sm font-medium text-gray-900 + system-sm-semibold text-text-secondary ` const descriptionClassName = ` - mt-1 text-xs font-normal text-gray-500 + mt-1 body-xs-regular text-text-tertiary ` const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ @@ -122,7 +122,7 @@ export default function AccountPage() {
-
{item.name}
+
{item.name}
) } @@ -130,7 +130,7 @@ export default function AccountPage() { return ( <>
-

{t('common.account.myAccount')}

+

{t('common.account.myAccount')}

@@ -142,10 +142,10 @@ export default function AccountPage() {
{t('common.account.name')}
-
+
{userProfile.name}
-
+
{t('common.operation.edit')}
@@ -153,7 +153,7 @@ export default function AccountPage() {
{t('common.account.email')}
-
+
{userProfile.email}
@@ -162,14 +162,14 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('common.account.password')}
-
{t('common.account.passwordTip')}
+
{t('common.account.password')}
+
{t('common.account.passwordTip')}
) } -
+
{t('common.account.langGeniusAccount')}
{t('common.account.langGeniusAccountTip')}
@@ -181,7 +181,7 @@ export default function AccountPage() { wrapperClassName='mt-2' /> )} - {!IS_CE_EDITION && } + {!IS_CE_EDITION && }
{ editNameModalVisible && ( @@ -190,7 +190,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className={s.modal} > -
{t('common.account.editName')}
+
{t('common.account.editName')}
{t('common.account.name')}
-
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
{userProfile.is_password_set && ( <>
{t('common.account.currentPassword')}
@@ -242,7 +242,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')}
@@ -261,7 +261,7 @@ export default function AccountPage() {
-
{t('common.account.confirmPassword')}
+
{t('common.account.confirmPassword')}
-
+
{t('common.account.deleteTip')}
{t('common.account.deleteConfirmTip')}
-
{`${t('common.account.delete')}: ${userProfile.email}`}
+
{`${t('common.account.delete')}: ${userProfile.email}`}
} confirmText={t('common.operation.ok') as string} diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx index 544e43ab27..8fdecc07bf 100644 --- a/web/app/account/avatar.tsx +++ b/web/app/account/avatar.tsx @@ -40,9 +40,9 @@ export default function AppSelector() { className={` inline-flex items-center rounded-[20px] p-1x text-sm - text-gray-700 hover:bg-gray-200 + text-text-primary mobile:px-1 - ${open && 'bg-gray-200'} + ${open && 'bg-components-panel-bg-blur'} `} > @@ -60,7 +60,7 @@ export default function AppSelector() { @@ -78,10 +78,10 @@ export default function AppSelector() {
handleLogout()}>
- -
{t('common.userProfile.logout')}
+ +
{t('common.userProfile.logout')}
diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx index 5aa8b05cbf..11a6abeab4 100644 --- a/web/app/account/layout.tsx +++ b/web/app/account/layout.tsx @@ -21,7 +21,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
-
+
{children}
diff --git a/web/app/components/app/create-app-dialog/app-card/index.tsx b/web/app/components/app/create-app-dialog/app-card/index.tsx index 254d67c923..f1807941ee 100644 --- a/web/app/components/app/create-app-dialog/app-card/index.tsx +++ b/web/app/components/app/create-app-dialog/app-card/index.tsx @@ -25,10 +25,10 @@ const AppCard = ({
diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 8f064c209e..383aeb1492 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -16,6 +16,7 @@ import { createContext, useContext } from 'use-context-selector' import { useShallow } from 'zustand/react/shallow' import { useTranslation } from 'react-i18next' import type { ChatItemInTree } from '../../base/chat/types' +import Indicator from '../../header/indicator' import VarPanel from './var-panel' import type { FeedbackFunc, FeedbackType, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type' import type { Annotation, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log' @@ -57,6 +58,12 @@ type IDrawerContext = { appDetail?: App } +type StatusCount = { + success: number + failed: number + partial_success: number +} + const DrawerContext = createContext({} as IDrawerContext) /** @@ -71,6 +78,33 @@ const HandThumbIconWithCount: FC<{ count: number; iconType: 'up' | 'down' }> = (
} +const statusTdRender = (statusCount: StatusCount) => { + if (statusCount.partial_success + statusCount.failed === 0) { + return ( +
+ + Success +
+ ) + } + else if (statusCount.failed === 0) { + return ( +
+ + Partial Success +
+ ) + } + else { + return ( +
+ + {statusCount.failed} {`${statusCount.failed > 1 ? 'Failures' : 'Failure'}`} +
+ ) + } +} + const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => { const newChatList: IChatItem[] = [] messages.forEach((item: ChatMessage) => { @@ -496,8 +530,8 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { } /** - * Text App Conversation Detail Component - */ + * Text App Conversation Detail Component + */ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { // Text Generator App Session Details Including Message List const detailParams = ({ url: `/apps/${appId}/completion-conversations/${conversationId}` }) @@ -542,8 +576,8 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st } /** - * Chat App Conversation Detail Component - */ + * Chat App Conversation Detail Component + */ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { const detailParams = { url: `/apps/${appId}/chat-conversations/${conversationId}` } const { data: conversationDetail } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchChatConversationDetail) @@ -585,8 +619,8 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string } } /** - * Conversation list component including basic information - */ + * Conversation list component including basic information + */ const ConversationList: FC = ({ logs, appDetail, onRefresh }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() @@ -597,6 +631,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const [showDrawer, setShowDrawer] = useState(false) // Whether to display the chat details drawer const [currentConversation, setCurrentConversation] = useState() // Currently selected conversation const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app + const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app const { setShowPromptLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({ setShowPromptLogModal: state.setShowPromptLogModal, setShowAgentLogModal: state.setShowAgentLogModal, @@ -639,6 +674,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) {isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')} {t('appLog.table.header.endUser')} + {isChatflow && {t('appLog.table.header.status')}} {isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')} {t('appLog.table.header.userRate')} {t('appLog.table.header.adminRate')} @@ -669,6 +705,9 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} {renderTdValue(endUser || defaultValue, !endUser)} + {isChatflow && + {statusTdRender(log.status_count)} + } {renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index e3de4a957f..41db9b5d46 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -63,6 +63,14 @@ const WorkflowAppLogList: FC = ({ logs, appDetail, onRefresh }) => {
) } + if (status === 'partial-succeeded') { + return ( +
+ + Partial Success +
+ ) + } } const onCloseDrawer = () => { diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx index c195b7253d..1938c42d3e 100644 --- a/web/app/components/base/app-icon/index.tsx +++ b/web/app/components/base/app-icon/index.tsx @@ -3,7 +3,6 @@ import type { FC } from 'react' import { init } from 'emoji-mart' import data from '@emoji-mart/data' -import Image from 'next/image' import { cva } from 'class-variance-authority' import type { AppIconType } from '@/types/app' import classNames from '@/utils/classnames' @@ -62,7 +61,8 @@ const AppIcon: FC = ({ onClick={onClick} > {isValidImageIcon - ? app icon + // eslint-disable-next-line @next/next/no-img-element + ? app icon : (innerIcon || ((icon && icon !== '') ? : )) } diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 4a09e27d98..bb9abdb6fc 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -64,6 +64,12 @@ const WorkflowProcessItem = ({ setShowMessageLogModal(true) }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + const showRetryDetail = useCallback(() => { + setCurrentLogItem(item) + setCurrentLogModalActiveTab('TRACING') + setShowMessageLogModal(true) + }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + return (
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index bf8efdb65a..044fc27858 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -28,6 +28,7 @@ export type InputProps = { destructive?: boolean wrapperClassName?: string styleCss?: CSSProperties + unit?: string } & React.InputHTMLAttributes & VariantProps const Input = ({ @@ -43,6 +44,7 @@ const Input = ({ value, placeholder, onChange, + unit, ...props }: InputProps) => { const { t } = useTranslation() @@ -80,6 +82,13 @@ const Input = ({ {destructive && ( )} + { + unit && ( +
+ {unit} +
+ ) + }
) } diff --git a/web/app/components/base/modal/index.tsx b/web/app/components/base/modal/index.tsx index 5b8c4be4b8..26cde5fce3 100644 --- a/web/app/components/base/modal/index.tsx +++ b/web/app/components/base/modal/index.tsx @@ -1,6 +1,6 @@ import { Dialog, Transition } from '@headlessui/react' import { Fragment } from 'react' -import { XMarkIcon } from '@heroicons/react/24/outline' +import { RiCloseLine } from '@remixicon/react' import classNames from '@/utils/classnames' // https://headlessui.com/react/dialog @@ -39,7 +39,7 @@ export default function Modal({ leaveFrom="opacity-100" leaveTo="opacity-0" > -
+
{title && {title} } - {description && + {description && {description} } {closable - &&
- + { e.stopPropagation() onClose() diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx index 89345fbe32..556a7bdf49 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -23,6 +23,7 @@ const SearchInput: FC = ({ const { t } = useTranslation() const [focus, setFocus] = useState(false) const isComposing = useRef(false) + const [internalValue, setInternalValue] = useState(value) return (
= ({ white && '!bg-white hover:!bg-white group-hover:!bg-white placeholder:!text-gray-400', )} placeholder={placeholder || t('common.operation.search')!} - value={value} + value={internalValue} onChange={(e) => { + setInternalValue(e.target.value) if (!isComposing.current) onChange(e.target.value) }} onCompositionStart={() => { isComposing.current = true }} - onCompositionEnd={() => { + onCompositionEnd={(e) => { isComposing.current = false + onChange(e.data) }} onFocus={() => setFocus(true)} onBlur={() => setFocus(false)} @@ -63,7 +66,10 @@ const SearchInput: FC = ({ {value && (
onChange('')} + onClick={() => { + onChange('') + setInternalValue('') + }} >
diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index c70cf24661..221d70355f 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -2,7 +2,8 @@ import type { FC } from 'react' import React, { Fragment, useEffect, useState } from 'react' import { Combobox, Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' +import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid' +import { RiCheckLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import classNames from '@/utils/classnames' import { @@ -152,7 +153,7 @@ const Select: FC = ({ 'absolute inset-y-0 right-0 flex items-center pr-4 text-gray-700', )} > -
### 基础 URL @@ -68,6 +68,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `image` 具体类型包含:'JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG' - `audio` 具体类型包含:'MP3', 'M4A', 'WAV', 'WEBM', 'AMR' - `video` 具体类型包含:'MP4', 'MOV', 'MPEG', 'MPGA' + - `custom` 具体类型包含:其他文件类型 - `transfer_method` (string) 传递方式: - `remote_url`: 图片地址。 - `local_file`: 上传文件。 @@ -450,6 +451,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -457,7 +461,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -465,7 +469,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 4e873b3294..d38e80407a 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -408,6 +408,9 @@ Chat applications support session persistence, allowing previous chat history to User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. + ### Response @@ -415,7 +418,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ Chat applications support session persistence, allowing previous chat history to --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` @@ -709,7 +713,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X GET '${props.appDetail.api_base_url}/conversations?user=abc-123&last_id=&limit=20' \ diff --git a/web/app/components/develop/template/template_chat.ja.mdx b/web/app/components/develop/template/template_chat.ja.mdx index b8914a4749..96db9912d5 100644 --- a/web/app/components/develop/template/template_chat.ja.mdx +++ b/web/app/components/develop/template/template_chat.ja.mdx @@ -408,6 +408,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。 + + メッセージのフィードバックです。 + ### 応答 @@ -415,7 +418,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` @@ -708,7 +712,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X GET '${props.appDetail.api_base_url}/conversations?user=abc-123&last_id=&limit=20' \ diff --git a/web/app/components/develop/template/template_chat.zh.mdx 
b/web/app/components/develop/template/template_chat.zh.mdx index 70242623b7..3d6e3630be 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -3,7 +3,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' # 对话型应用 API -对话应用支持会话持久化,可将之前的聊天记录作为上下进行回答,可适用于聊天/客服 AI 等。 +对话应用支持会话持久化,可将之前的聊天记录作为上下文进行回答,可适用于聊天/客服 AI 等。
### 基础 URL @@ -423,6 +423,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -430,7 +433,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -438,7 +441,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index cfa5a60d47..58c533c60b 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -60,6 +60,7 @@ Workflow applications offers non-session support and is ideal for translation, a - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (Other file types) - `transfer_method` (string) Transfer method, `remote_url` for image URL / `local_file` for file upload - `url` (string) Image URL (when the transfer method is `remote_url`) - `upload_file_id` (string) Uploaded file ID, which must be obtained by uploading through the File Upload API in advance (when the transfer method is `local_file`) diff --git a/web/app/components/develop/template/template_workflow.ja.mdx b/web/app/components/develop/template/template_workflow.ja.mdx index b6f8fb543f..2653b4913d 100644 --- a/web/app/components/develop/template/template_workflow.ja.mdx +++ b/web/app/components/develop/template/template_workflow.ja.mdx @@ -60,6 +60,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') - `audio` ('MP3', 'M4A', 'WAV', 'WEBM', 'AMR') - `video` ('MP4', 'MOV', 'MPEG', 'MPGA') + - `custom` (他のファイルタイプ) - `transfer_method` (string) 転送方法、画像URLの場合は`remote_url` / ファイルアップロードの場合は`local_file` - `url` (string) 画像URL(転送方法が`remote_url`の場合) - `upload_file_id` (string) アップロードされたファイルID、事前にファイルアップロードAPIを通じて取得する必要があります(転送方法が`local_file`の場合) diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index 9cef3d18a5..ddffc0f02d 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -58,6 +58,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - `image` 具体类型包含:'JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG' - `audio` 具体类型包含:'MP3', 'M4A', 'WAV', 'WEBM', 'AMR' - `video` 具体类型包含:'MP4', 'MOV', 'MPEG', 'MPGA' + - `custom` 具体类型包含:其他文件类型 - `transfer_method` (string) 传递方式,`remote_url` 图片地址 / `local_file` 上传文件 - `url` (string) 图片地址(仅当传递方式为 `remote_url` 时) - `upload_file_id` (string) (string) 上传文件 ID(仅当传递方式为 `local_file` 时) diff --git a/web/app/components/explore/app-card/index.tsx b/web/app/components/explore/app-card/index.tsx index b1ea4a95bf..f1826395f7 100644 --- a/web/app/components/explore/app-card/index.tsx +++ b/web/app/components/explore/app-card/index.tsx @@ -28,10 +28,10 @@ const AppCard = ({
{appBasicInfo.mode === 'advanced-chat' && ( diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 14f079c0f2..a30ab175ac 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -3,7 +3,7 @@ import { useTranslation } from 'react-i18next' import { Fragment, useState } from 'react' import { useRouter } from 'next/navigation' import { useContext } from 'use-context-selector' -import { RiArrowDownSLine } from '@remixicon/react' +import { RiArrowDownSLine, RiLogoutBoxRLine } from '@remixicon/react' import Link from 'next/link' import { Menu, Transition } from '@headlessui/react' import Indicator from '../indicator' @@ -16,7 +16,6 @@ import Avatar from '@/app/components/base/avatar' import { logout } from '@/service/common' import { useAppContext } from '@/context/app-context' import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows' -import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import { useModalContext } from '@/context/modal-context' import { LanguagesSupported } from '@/i18n/language' import { useProviderContext } from '@/context/provider-context' @@ -28,8 +27,8 @@ export type IAppSelector = { export default function AppSelector({ isMobile }: IAppSelector) { const itemClassName = ` - flex items-center w-full h-9 px-3 text-gray-700 text-[14px] - rounded-lg font-normal hover:bg-gray-50 cursor-pointer + flex items-center w-full h-9 px-3 text-text-secondary system-md-regular + rounded-lg hover:bg-state-base-hover cursor-pointer ` const router = useRouter() const [aboutVisible, setAboutVisible] = useState(false) @@ -89,7 +88,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { @@ -97,13 +96,13 @@ export default function AppSelector({ isMobile }: IAppSelector) {
-
{userProfile.name}
-
{userProfile.email}
+
{userProfile.name}
+
{userProfile.email}
-
{t('common.userProfile.workspace')}
+
{t('common.userProfile.workspace')}
@@ -113,7 +112,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='/account' target='_self' rel='noopener noreferrer'>
{t('common.account.account')}
- + @@ -127,7 +126,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href={mailToSupport(userProfile.email, plan.type, langeniusVersionInfo.current_version)} target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.emailSupport')}
- +
} @@ -136,7 +135,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://github.com/langgenius/dify/discussions/categories/feedbacks' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.communityFeedback')}
- +
@@ -145,7 +144,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://discord.gg/5AEfbxcd9k' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.community')}
- +
@@ -156,7 +155,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { } target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.helpCenter')}
- +
@@ -165,7 +164,7 @@ export default function AppSelector({ isMobile }: IAppSelector) { href='https://roadmap.dify.ai' target='_blank' rel='noopener noreferrer'>
{t('common.userProfile.roadmap')}
- +
{ @@ -174,7 +173,7 @@ export default function AppSelector({ isMobile }: IAppSelector) {
setAboutVisible(true)}>
{t('common.userProfile.about')}
-
{langeniusVersionInfo.current_version}
+
{langeniusVersionInfo.current_version}
@@ -185,10 +184,10 @@ export default function AppSelector({ isMobile }: IAppSelector) {
handleLogout()}>
-
{t('common.userProfile.logout')}
- +
{t('common.userProfile.logout')}
+
diff --git a/web/app/components/header/account-setting/collapse/index.tsx b/web/app/components/header/account-setting/collapse/index.tsx index a70dca16e5..d0068dabed 100644 --- a/web/app/components/header/account-setting/collapse/index.tsx +++ b/web/app/components/header/account-setting/collapse/index.tsx @@ -25,18 +25,18 @@ const Collapse = ({ const toggle = () => setOpen(!open) return ( -
-
+
+
{title} { open - ? - : + ? + : }
{ open && ( -
+
{ items.map(item => (
onSelect && onSelect(item)}> diff --git a/web/app/components/header/account-setting/data-source-page/index.tsx b/web/app/components/header/account-setting/data-source-page/index.tsx index c3da977ca4..93dc2db854 100644 --- a/web/app/components/header/account-setting/data-source-page/index.tsx +++ b/web/app/components/header/account-setting/data-source-page/index.tsx @@ -12,7 +12,6 @@ export default function DataSourcePage() { return (
-
{t('common.dataSource.add')}
diff --git a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx index 2a05808e2a..b7fd8193e2 100644 --- a/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx +++ b/web/app/components/header/account-setting/data-source-page/panel/config-item.tsx @@ -44,22 +44,22 @@ const ConfigItem: FC = ({ const onChangeAuthorizedPage = notionActions?.onChangeAuthorizedPage || function () { } return ( -
+
-
{payload.name}
+
{payload.name}
{ payload.isActive - ? + ? : } -
+
{ payload.isActive ? t(isNotion ? 'common.dataSource.notion.connected' : 'common.dataSource.website.active') : t(isNotion ? 'common.dataSource.notion.disconnected' : 'common.dataSource.website.inactive') }
-
+
{isNotion && ( = ({ { isWebsite && !readOnly && ( -
- +
+
) } diff --git a/web/app/components/header/account-setting/data-source-page/panel/index.tsx b/web/app/components/header/account-setting/data-source-page/panel/index.tsx index 4a810020b4..8d2ec0a8ca 100644 --- a/web/app/components/header/account-setting/data-source-page/panel/index.tsx +++ b/web/app/components/header/account-setting/data-source-page/panel/index.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { PlusIcon } from '@heroicons/react/24/solid' +import { RiAddLine } from '@remixicon/react' import type { ConfigItemType } from './config-item' import ConfigItem from './config-item' @@ -41,12 +41,12 @@ const Panel: FC = ({ const isWebsite = type === DataSourceType.website return ( -
+
-
+
-
{t(`common.dataSource.${type}.title`)}
+
{t(`common.dataSource.${type}.title`)}
{isWebsite && (
{t('common.dataSource.website.with')} { provider === DataSourceProvider.fireCrawl ? '🔥 Firecrawl' : 'Jina Reader'} @@ -55,7 +55,7 @@ const Panel: FC = ({
{ !isConfigured && ( -
+
{t(`common.dataSource.${type}.description`)}
) @@ -81,13 +81,13 @@ const Panel: FC = ({ <> {isSupportList &&
- - {t('common.dataSource.notion.addWorkspace')} + + {t('common.dataSource.connect')}
} ) @@ -98,8 +98,8 @@ const Panel: FC = ({ {isWebsite && !isConfigured && (
= ({ isConfigured && ( <>
-
+
{isNotion ? t('common.dataSource.notion.connectedWorkspace') : t('common.dataSource.website.configuredCrawlers')}
-
+
{ diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index d829f6b77b..4be7ec6ab7 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -152,14 +152,14 @@ export default function AccountSetting({ wrapperClassName='pt-[60px]' >
-
-
{t('common.userProfile.settings')}
+
+
{t('common.userProfile.settings')}
{ menuItems.map(menuItem => (
{!isCurrentWorkspaceDatasetOperator && ( -
{menuItem.name}
+
{menuItem.name}
)}
{ @@ -168,7 +168,7 @@ export default function AccountSetting({ key={item.key} className={` flex items-center h-[37px] mb-[2px] text-sm cursor-pointer rounded-lg - ${activeMenu === item.key ? 'font-semibold text-primary-600 bg-primary-50' : 'font-light text-gray-700'} + ${activeMenu === item.key ? 'system-sm-semibold text-components-menu-item-text-active bg-state-base-active' : 'system-sm-medium text-components-menu-item-text'} `} title={item.name} onClick={() => setActiveMenu(item.key)} @@ -185,7 +185,7 @@ export default function AccountSetting({
-
+
{activeItem?.name}
{ activeItem?.description && ( @@ -193,8 +193,8 @@ export default function AccountSetting({ ) }
-
- +
+
diff --git a/web/app/components/header/account-setting/language-page/index.tsx b/web/app/components/header/account-setting/language-page/index.tsx index fc8db86813..7d3e09fc21 100644 --- a/web/app/components/header/account-setting/language-page/index.tsx +++ b/web/app/components/header/account-setting/language-page/index.tsx @@ -13,7 +13,7 @@ import { timezones } from '@/utils/timezone' import { languages } from '@/i18n/language' const titleClassName = ` - mb-2 text-sm font-medium text-gray-900 + mb-2 system-sm-semibold text-text-secondary ` export default function LanguagePage() { diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 03d65af7a4..808da454d1 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -34,7 +34,7 @@ const MembersPage = () => { } const { locale } = useContext(I18n) - const { userProfile, currentWorkspace, isCurrentWorkspaceOwner, isCurrentWorkspaceManager } = useAppContext() + const { userProfile, currentWorkspace, isCurrentWorkspaceOwner, isCurrentWorkspaceManager, systemFeatures } = useAppContext() const { data, mutate } = useSWR({ url: '/workspaces/current/members' }, fetchMembers) const [inviteModalVisible, setInviteModalVisible] = useState(false) const [invitationResults, setInvitationResults] = useState([]) @@ -85,32 +85,32 @@ const MembersPage = () => {
-
-
{t('common.members.name')}
-
{t('common.members.lastActive')}
-
{t('common.members.role')}
+
+
{t('common.members.name')}
+
{t('common.members.lastActive')}
+
{t('common.members.role')}
{ accounts.map(account => ( -
+
-
+
{account.name} - {account.status === 'pending' && {t('common.members.pending')}} - {userProfile.email === account.email && {t('common.members.you')}} + {account.status === 'pending' && {t('common.members.pending')}} + {userProfile.email === account.email && {t('common.members.you')}}
-
{account.email}
+
{account.email}
-
{dayjs(Number((account.last_active_at || account.created_at)) * 1000).locale(locale === 'zh-Hans' ? 'zh-cn' : 'en').fromNow()}
+
{dayjs(Number((account.last_active_at || account.created_at)) * 1000).locale(locale === 'zh-Hans' ? 'zh-cn' : 'en').fromNow()}
{ ((isCurrentWorkspaceOwner && account.role !== 'owner') || (isCurrentWorkspaceManager && !['owner', 'admin'].includes(account.role))) ? - :
{RoleMap[account.role] || RoleMap.normal}
+ :
{RoleMap[account.role] || RoleMap.normal}
}
@@ -122,6 +122,7 @@ const MembersPage = () => { { inviteModalVisible && ( setInviteModalVisible(false)} onSend={(invitationResults) => { setInvitedModalVisible(true) diff --git a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx index 7d43495362..197e3ee867 100644 --- a/web/app/components/header/account-setting/members-page/invite-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/invite-modal/index.tsx @@ -4,6 +4,7 @@ import { useContext } from 'use-context-selector' import { XMarkIcon } from '@heroicons/react/24/outline' import { useTranslation } from 'react-i18next' import { ReactMultiEmail } from 'react-multi-email' +import { RiErrorWarningFill } from '@remixicon/react' import RoleSelector from './role-selector' import s from './index.module.css' import cn from '@/utils/classnames' @@ -17,11 +18,13 @@ import I18n from '@/context/i18n' import 'react-multi-email/dist/style.css' type IInviteModalProps = { + isEmailSetup: boolean onCancel: () => void onSend: (invitationResults: InvitationResult[]) => void } const InviteModal = ({ + isEmailSetup, onCancel, onSend, }: IInviteModalProps) => { @@ -59,7 +62,23 @@ const InviteModal = ({
{t('common.members.inviteTeamMember')}
-
{t('common.members.inviteTeamMemberTip')}
+
{t('common.members.inviteTeamMemberTip')}
+ {!isEmailSetup && ( +
+
+
+
+
+ +
+
+ {t('common.members.emailNotSetup')} +
+
+
+
+ )} +
{t('common.members.email')}
diff --git a/web/app/components/header/header-wrapper.tsx b/web/app/components/header/header-wrapper.tsx index 52728bea87..dd0ec77b82 100644 --- a/web/app/components/header/header-wrapper.tsx +++ b/web/app/components/header/header-wrapper.tsx @@ -11,7 +11,7 @@ const HeaderWrapper = ({ children, }: HeaderWrapperProps) => { const pathname = usePathname() - const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools', '/account'].includes(pathname) + const isBordered = ['/apps', '/datasets', '/datasets/create', '/tools'].includes(pathname) return (
{ // eslint-disable-next-line react-hooks/exhaustive-deps }, [selectedSegment]) return ( -
+
{isMobile &&
{ const { t } = useTranslation() + const searchParams = useSearchParams() return (
-
{t('tools.addToolModal.emptyTitle')}
-
{t('tools.addToolModal.emptyTip')}
+
+ {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTitle' : 'emptyTitleCustom'}`)} +
+
+ {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTip' : 'emptyTipCustom'}`)} +
) } diff --git a/web/app/components/workflow/block-selector/all-tools.tsx b/web/app/components/workflow/block-selector/all-tools.tsx index dc15313216..aaa3811251 100644 --- a/web/app/components/workflow/block-selector/all-tools.tsx +++ b/web/app/components/workflow/block-selector/all-tools.tsx @@ -43,23 +43,23 @@ const AllTools = ({ return mergedTools.filter((toolWithProvider) => { return isMatchingKeywords(toolWithProvider.name, searchText) - || toolWithProvider.tools.some((tool) => { - return Object.values(tool.label).some((label) => { - return isMatchingKeywords(label, searchText) + || toolWithProvider.tools.some((tool) => { + return Object.values(tool.label).some((label) => { + return isMatchingKeywords(label, searchText) + }) }) - }) }) }, [activeTab, buildInTools, customTools, workflowTools, searchText]) return (
-
+
{ tabs.map(tab => (
setActiveTab(tab.key)} diff --git a/web/app/components/workflow/block-selector/blocks.tsx b/web/app/components/workflow/block-selector/blocks.tsx index a1bada1a8e..eaaa473f3d 100644 --- a/web/app/components/workflow/block-selector/blocks.tsx +++ b/web/app/components/workflow/block-selector/blocks.tsx @@ -58,7 +58,7 @@ const Blocks = ({ > { classification !== '-' && !!list.length && ( -
+
{t(`workflow.tabs.${classification}`)}
) @@ -68,7 +68,7 @@ const Blocks = ({ -
{block.title}
-
{nodesExtraData[block.type].about}
+
{block.title}
+
{nodesExtraData[block.type].about}
)} >
onSelect(block.type)} > -
{block.title}
+
{block.title}
)) @@ -103,7 +103,7 @@ const Blocks = ({
{ isEmpty && ( -
{t('workflow.tabs.noResult')}
+
{t('workflow.tabs.noResult')}
) } { diff --git a/web/app/components/workflow/block-selector/index-bar.tsx b/web/app/components/workflow/block-selector/index-bar.tsx index 6eab51246d..2a4cbad432 100644 --- a/web/app/components/workflow/block-selector/index-bar.tsx +++ b/web/app/components/workflow/block-selector/index-bar.tsx @@ -47,9 +47,9 @@ const IndexBar: FC = ({ letters, itemRefs }) => { element.scrollIntoView({ behavior: 'smooth' }) } return ( -
+
{letters.map(letter => ( -
handleIndexClick(letter)}> +
handleIndexClick(letter)}> {letter}
))} diff --git a/web/app/components/workflow/block-selector/index.tsx b/web/app/components/workflow/block-selector/index.tsx index 6f05ba16fb..dc93c275f2 100644 --- a/web/app/components/workflow/block-selector/index.tsx +++ b/web/app/components/workflow/block-selector/index.tsx @@ -25,6 +25,7 @@ import Input from '@/app/components/base/input' import { Plus02, } from '@/app/components/base/icons/src/vender/line/general' +import classNames from '@/utils/classnames' type NodeSelectorProps = { open?: boolean @@ -114,19 +115,21 @@ const NodeSelector: FC = ({
- +
) } -
-
e.stopPropagation()}> +
+
e.stopPropagation()}> = ({
e.stopPropagation()}> { !noBlocks && ( -
+
{ tabs.map(tab => (
onActiveTabChange(tab.key)} > diff --git a/web/app/components/workflow/block-selector/tools.tsx b/web/app/components/workflow/block-selector/tools.tsx index a2ae845997..394966fb4f 100644 --- a/web/app/components/workflow/block-selector/tools.tsx +++ b/web/app/components/workflow/block-selector/tools.tsx @@ -45,7 +45,7 @@ const Blocks = ({ -
{tool.label[language]}
-
{tool.description[language]}
+
{tool.label[language]}
+
{tool.description[language]}
)} >
onSelect(BlockEnum.Tool, { provider_id: toolWithProvider.id, provider_type: toolWithProvider.type, @@ -75,7 +75,7 @@ const Blocks = ({ type={BlockEnum.Tool} toolIcon={toolWithProvider.icon} /> -
{tool.label[language]}
+
{tool.label[language]}
)) @@ -100,7 +100,7 @@ const Blocks = ({
{ !tools.length && !showWorkflowEmpty && ( -
{t('workflow.tabs.noResult')}
+
{t('workflow.tabs.noResult')}
) } {!tools.length && showWorkflowEmpty && ( diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index ffa14b347b..d04163b853 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -506,3 +506,5 @@ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' +export const DEFAULT_RETRY_MAX = 3 +export const DEFAULT_RETRY_INTERVAL = 100 diff --git a/web/app/components/workflow/header/editing-title.tsx b/web/app/components/workflow/header/editing-title.tsx index 44a85631dc..9148420cbe 100644 --- a/web/app/components/workflow/header/editing-title.tsx +++ b/web/app/components/workflow/header/editing-title.tsx @@ -13,7 +13,7 @@ const EditingTitle = () => { const isSyncingWorkflowDraft = useStore(s => s.isSyncingWorkflowDraft) return ( -
+
{ !!draftUpdatedAt && ( <> diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx index 010d9ca1cd..6e46990df8 100644 --- a/web/app/components/workflow/header/index.tsx +++ b/web/app/components/workflow/header/index.tsx @@ -27,6 +27,7 @@ import { } from '../hooks' import AppPublisher from '../../app/app-publisher' import { ToastContext } from '../../base/toast' +import Divider from '../../base/divider' import RunAndHistory from './run-and-history' import EditingTitle from './editing-title' import RunningTitle from './running-title' @@ -144,15 +145,12 @@ const Header: FC = () => { return (
{ appSidebarExpand === 'collapse' && ( -
{appDetail?.name}
+
{appDetail?.name}
) } { @@ -171,7 +169,7 @@ const Header: FC = () => { {/* */} {isChatMode && } -
+ -
+
- )} - -
- {isRunning && ( - - )} - {isFinished && ( - <> - {result} - - )} -
+ {isRunning && ( + + )} + {isFinished && ( + <> + {result} + + )} +
+ ) + }
) diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index f11f8bd5fb..89412cabb3 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -14,7 +14,6 @@ import type { CommonNodeType, Node, } from '@/app/components/workflow/types' -import Split from '@/app/components/workflow/nodes/_base/components/split' import Tooltip from '@/app/components/base/tooltip' type ErrorHandleProps = Pick @@ -45,7 +44,6 @@ const ErrorHandle = ({ return ( <> -
{ + const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate() + + const handleRetryConfigChange = useCallback((value?: WorkflowRetryConfig) => { + handleNodeDataUpdateWithSyncDraft({ + id, + data: { + retry_config: value, + }, + }) + }, [id, handleNodeDataUpdateWithSyncDraft]) + + return { + handleRetryConfigChange, + } +} + +export const useRetryDetailShowInSingleRun = () => { + const [retryDetails, setRetryDetails] = useState() + + const handleRetryDetailsChange = useCallback((details: NodeTracing[] | undefined) => { + setRetryDetails(details) + }, []) + + return { + retryDetails, + handleRetryDetailsChange, + } +} diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx new file mode 100644 index 0000000000..34c3e28d2c --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx @@ -0,0 +1,91 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAlertFill, + RiCheckboxCircleFill, + RiLoader2Line, +} from '@remixicon/react' +import type { Node } from '@/app/components/workflow/types' +import { NodeRunningStatus } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' + +type RetryOnNodeProps = Pick +const RetryOnNode = ({ + data, +}: RetryOnNodeProps) => { + const { t } = useTranslation() + const { retry_config } = data + const showSelectedBorder = data.selected || data._isBundled || data._isEntering + const { + isRunning, + isSuccessful, + isException, + isFailed, + } = useMemo(() => { + return { + isRunning: data._runningStatus === NodeRunningStatus.Running && !showSelectedBorder, + isSuccessful: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder, + isFailed: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder, + isException: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder, + } + }, [data._runningStatus, showSelectedBorder]) + const showDefault = !isRunning && !isSuccessful && !isException && !isFailed + + if (!retry_config?.retry_enabled) + return null + + if (!showDefault && !data._retryIndex) + return null + + return ( +
+
+
+ { + showDefault && ( + t('workflow.nodes.common.retry.retryTimes', { times: retry_config.max_retries }) + ) + } + { + isRunning && ( + <> + + {t('workflow.nodes.common.retry.retrying')} + + ) + } + { + isSuccessful && ( + <> + + {t('workflow.nodes.common.retry.retrySuccessful')} + + ) + } + { + (isFailed || isException) && ( + <> + + {t('workflow.nodes.common.retry.retryFailed')} + + ) + } +
+ { + !showDefault && !!data._retryIndex && ( +
+ {data._retryIndex}/{data.retry_config?.max_retries} +
+ ) + } +
+
+ ) +} + +export default RetryOnNode diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx new file mode 100644 index 0000000000..dc877a632c --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx @@ -0,0 +1,117 @@ +import { useTranslation } from 'react-i18next' +import { useRetryConfig } from './hooks' +import s from './style.module.css' +import Switch from '@/app/components/base/switch' +import Slider from '@/app/components/base/slider' +import Input from '@/app/components/base/input' +import type { + Node, +} from '@/app/components/workflow/types' +import Split from '@/app/components/workflow/nodes/_base/components/split' + +type RetryOnPanelProps = Pick +const RetryOnPanel = ({ + id, + data, +}: RetryOnPanelProps) => { + const { t } = useTranslation() + const { handleRetryConfigChange } = useRetryConfig(id) + const { retry_config } = data + + const handleRetryEnabledChange = (value: boolean) => { + handleRetryConfigChange({ + retry_enabled: value, + max_retries: retry_config?.max_retries || 3, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleMaxRetriesChange = (value: number) => { + if (value > 10) + value = 10 + else if (value < 1) + value = 1 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: value, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleRetryIntervalChange = (value: number) => { + if (value > 5000) + value = 5000 + else if (value < 100) + value = 100 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: retry_config?.max_retries || 3, + retry_interval: value, + }) + } + + return ( + <> +
+
+
+
{t('workflow.nodes.common.retry.retryOnFailure')}
+
+ handleRetryEnabledChange(v)} + /> +
+ { + retry_config?.retry_enabled && ( +
+
+
{t('workflow.nodes.common.retry.maxRetries')}
+ + handleMaxRetriesChange(e.target.value as any)} + min={1} + max={10} + unit={t('workflow.nodes.common.retry.times') || ''} + className={s.input} + /> +
+
+
{t('workflow.nodes.common.retry.retryInterval')}
+ + handleRetryIntervalChange(e.target.value as any)} + min={100} + max={5000} + unit={t('workflow.nodes.common.retry.ms') || ''} + className={s.input} + /> +
+
+ ) + } +
+ + + ) +} + +export default RetryOnPanel diff --git a/web/app/components/workflow/nodes/_base/components/retry/style.module.css b/web/app/components/workflow/nodes/_base/components/retry/style.module.css new file mode 100644 index 0000000000..2ce8717af8 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/style.module.css @@ -0,0 +1,5 @@ +.input::-webkit-inner-spin-button, +.input::-webkit-outer-spin-button { + -webkit-appearance: none; + margin: 0; +} \ No newline at end of file diff --git a/web/app/components/workflow/nodes/_base/components/retry/types.ts b/web/app/components/workflow/nodes/_base/components/retry/types.ts new file mode 100644 index 0000000000..bb5f593fd5 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/types.ts @@ -0,0 +1,5 @@ +export type WorkflowRetryConfig = { + max_retries: number + retry_interval: number + retry_enabled: boolean +} diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/web/app/components/workflow/nodes/_base/components/retry/utils.ts similarity index 100% rename from api/tests/integration_tests/model_runtime/x/__init__.py rename to web/app/components/workflow/nodes/_base/components/retry/utils.ts diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index f2da2da35a..4807fa3b2b 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,7 +25,10 @@ import { useNodesReadOnly, useToolIcon, } from '../../hooks' -import { hasErrorHandleNode } from '../../utils' +import { + hasErrorHandleNode, + hasRetryNode, +} from '../../utils' import { useNodeIterationInteractions } from '../iteration/use-interactions' import type { IterationNodeType } from '../iteration/types' import { @@ -35,6 +38,7 @@ import { import NodeResizer from './components/node-resizer' import NodeControl from './components/node-control' import ErrorHandleOnNode from './components/error-handle/error-handle-on-node' +import RetryOnNode from './components/retry/retry-on-node' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' @@ -237,6 +241,14 @@ const BaseNode: FC = ({
) } + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = ({
{cloneElement(children, { id, data })}
+ + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = { defaultValue: { @@ -24,6 +27,11 @@ const nodeDefault: NodeDefault = { max_read_timeout: 0, max_write_timeout: 0, }, + retry_config: { + retry_enabled: true, + max_retries: 3, + retry_interval: 100, + }, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/http/panel.tsx b/web/app/components/workflow/nodes/http/panel.tsx index 5c613aa0f3..91b3a6140d 100644 --- a/web/app/components/workflow/nodes/http/panel.tsx +++ b/web/app/components/workflow/nodes/http/panel.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React from 'react' +import { memo } from 'react' import { useTranslation } from 'react-i18next' import useConfig from './use-config' import ApiInput from './components/api-input' @@ -18,6 +18,7 @@ import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files' import type { NodePanelProps } from '@/app/components/workflow/types' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.http' @@ -60,6 +61,10 @@ const Panel: FC> = ({ hideCurlPanel, handleCurlImport, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() // To prevent prompt editor in body not update data. if (!isDataReady) return null @@ -181,6 +186,7 @@ const Panel: FC> = ({ {isShowSingleRun && ( > = ({ runningStatus={runningStatus} onRun={handleRun} onStop={handleStop} - result={} + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )} {(isShowCurlPanel && !readOnly) && ( @@ -207,4 +215,4 @@ const Panel: FC> = ({ ) } -export default React.memo(Panel) +export default memo(Panel) diff --git a/web/app/components/workflow/nodes/iteration/use-config.ts b/web/app/components/workflow/nodes/iteration/use-config.ts index 6fb8797dcd..fd69fecaf0 100644 --- a/web/app/components/workflow/nodes/iteration/use-config.ts +++ b/web/app/components/workflow/nodes/iteration/use-config.ts @@ -52,6 +52,12 @@ const useConfig = (id: string, payload: IterationNodeType) => { [VarType.number]: VarType.arrayNumber, [VarType.object]: VarType.arrayObject, [VarType.file]: VarType.arrayFile, + // list operator node can output array + [VarType.array]: VarType.array, + [VarType.arrayFile]: VarType.arrayFile, + [VarType.arrayString]: VarType.arrayString, + [VarType.arrayNumber]: VarType.arrayNumber, + [VarType.arrayObject]: VarType.arrayObject, } as Record)[outputItemType] || VarType.arrayString }) setInputs(newInputs) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index e9da9acccc..794fcbca4a 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -129,9 +129,6 @@ export const getMultipleRetrievalConfig = ( reranking_enable: ((allInternal && allEconomic) || allExternal) ? 
reranking_enable : true, } - if (!rerankModelIsValid) - result.reranking_model = undefined - const setDefaultWeights = () => { result.weights = { vector_setting: { @@ -198,7 +195,6 @@ export const getMultipleRetrievalConfig = ( setDefaultWeights() } } - if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { result.reranking_mode = RerankingModeEnum.WeightedScore setDefaultWeights() diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 21ef6395b1..60f68d93e2 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -19,6 +19,7 @@ import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/c import ResultPanel from '@/app/components/workflow/run/result-panel' import Tooltip from '@/app/components/base/tooltip' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.llm' @@ -69,6 +70,10 @@ const Panel: FC> = ({ runResult, filterJinjia2InputVar, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() const model = inputs.model @@ -282,12 +287,15 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )}
diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index 5d29b767ad..bab7c20d5b 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -61,20 +61,12 @@ const InputVarList: FC = ({ const newValue = produce(value, (draft: ToolVarInputs) => { const target = draft[variable] if (target) { - if (!isSupportConstantValue || varKindType === VarKindType.variable) { - if (isSupportConstantValue) - target.type = VarKindType.variable - - target.value = varValue as ValueSelector - } - else { - target.type = VarKindType.constant - target.value = varValue as string - } + target.type = varKindType + target.value = varValue } else { draft[variable] = { - type: VarKindType.variable, + type: varKindType, value: varValue, } } @@ -170,10 +162,10 @@ const InputVarList: FC = ({ readonly={readOnly} isShowNodeName nodeId={nodeId} - value={varInput?.type === VarKindType.constant ? (varInput?.value || '') : (varInput?.value || [])} + value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])} onChange={handleNotMixedTypeChange(variable)} onOpen={handleOpen(index)} - defaultVarKindType={isNumber ? VarKindType.constant : VarKindType.variable} + defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)} isSupportConstantValue={isSupportConstantValue} filterVar={isNumber ? filterVar : undefined} availableVars={isSelect ? availableVars : undefined} diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 49e645faa4..d0d4c3a839 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -14,6 +14,8 @@ import Loading from '@/app/components/base/loading' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' +import { useToolIcon } from '@/app/components/workflow/hooks' const i18nPrefix = 'workflow.nodes.tool' @@ -48,6 +50,11 @@ const Panel: FC> = ({ handleStop, runResult, } = useConfig(id, data) + const toolIcon = useToolIcon(data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() if (isLoading) { return
@@ -143,12 +150,16 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )}
diff --git a/web/app/components/workflow/nodes/tool/use-config.ts b/web/app/components/workflow/nodes/tool/use-config.ts index df8ad47985..94046ba4fa 100644 --- a/web/app/components/workflow/nodes/tool/use-config.ts +++ b/web/app/components/workflow/nodes/tool/use-config.ts @@ -132,7 +132,7 @@ const useConfig = (id: string, payload: ToolNodeType) => { draft.tool_parameters = {} }) setInputs(inputsWithDefaultValue) - // eslint-disable-next-line react-hooks/exhaustive-deps + // eslint-disable-next-line react-hooks/exhaustive-deps }, [currTool]) // setting when call @@ -214,8 +214,13 @@ const useConfig = (id: string, payload: ToolNodeType) => { .map(k => inputs.tool_parameters[k]) const varInputs = getInputVars(hadVarParams.map((p) => { - if (p.type === VarType.variable) + if (p.type === VarType.variable) { + // handle the old wrong value not crash the page + if (!(p.value as any).join) + return `{{#${p.value}#}}` + return `{{#${(p.value as ValueSelector).join('.')}#}}` + } return p.value as string })) diff --git a/web/app/components/workflow/operator/add-block.tsx b/web/app/components/workflow/operator/add-block.tsx index 388fbc053f..32f0007293 100644 --- a/web/app/components/workflow/operator/add-block.tsx +++ b/web/app/components/workflow/operator/add-block.tsx @@ -78,9 +78,9 @@ const AddBlock = ({ title={t('workflow.common.addBlock')} >
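Aside on the tool/use-config.ts hunk above: the added guard builds the `{{#node_id.var#}}` template while tolerating legacy drafts that stored a plain string where a ValueSelector (string[]) is expected. A minimal sketch of that conversion, with an illustrative helper name (`toVarTemplate`) that is not part of the PR:

type ValueSelector = string[]

// Old drafts may have saved a plain string instead of a selector array,
// so guard before joining to avoid crashing the page on malformed data.
function toVarTemplate(value: string | ValueSelector): string {
  if (!Array.isArray(value))
    return `{{#${value}#}}` // legacy wrong value: wrap it as-is
  return `{{#${value.join('.')}#}}` // ['node_id', 'text'] -> '{{#node_id.text#}}'
}

The diff itself tests `!(p.value as any).join`, which amounts to the same string-vs-array check at runtime.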
diff --git a/web/app/components/workflow/operator/control.tsx b/web/app/components/workflow/operator/control.tsx index 7c67b70816..cd18def056 100644 --- a/web/app/components/workflow/operator/control.tsx +++ b/web/app/components/workflow/operator/control.tsx @@ -18,6 +18,7 @@ import { ControlMode, } from '../types' import { useStore } from '../store' +import Divider from '../../base/divider' import AddBlock from './add-block' import TipPopup from './tip-popup' import { useOperator } from './hooks' @@ -43,26 +44,26 @@ const Control = () => { } return ( -
+
-
+
@@ -73,20 +74,20 @@ const Control = () => {
-
+
diff --git a/web/app/components/workflow/operator/index.tsx b/web/app/components/workflow/operator/index.tsx index 043bd60aae..80c2bb5306 100644 --- a/web/app/components/workflow/operator/index.tsx +++ b/web/app/components/workflow/operator/index.tsx @@ -17,7 +17,9 @@ const Operator = ({ handleUndo, handleRedo }: OperatorProps) => { width: 102, height: 72, }} - className='!absolute !left-4 !bottom-14 z-[9] !m-0 !w-[102px] !h-[72px] !border-[0.5px] !border-black/8 !rounded-lg !shadow-lg' + maskColor='var(--color-shadow-shadow-5)' + className='!absolute !left-4 !bottom-14 z-[9] !m-0 !w-[102px] !h-[72px] !border-[0.5px] !border-divider-subtle + !rounded-lg !shadow-md !shadow-shadow-shadow-5 !bg-workflow-minimap-bg' />
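For context, the retry-on-panel.tsx file added earlier in this diff clamps its two numeric inputs before syncing node data: max_retries to 1-10 and retry_interval to 100-5000 ms, with fallbacks of 3 retries and 1000 ms when a field is unset. A condensed sketch of that normalization (helper names are illustrative, not part of the PR):

const clamp = (value: number, min: number, max: number): number =>
  Math.min(max, Math.max(min, value))

function normalizeRetryConfig(maxRetries = 3, retryInterval = 1000) {
  return {
    retry_enabled: true,
    max_retries: clamp(maxRetries, 1, 10), // bounds from handleMaxRetriesChange
    retry_interval: clamp(retryInterval, 100, 5000), // ms, bounds from handleRetryIntervalChange
  }
}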
diff --git a/web/app/components/workflow/operator/tip-popup.tsx b/web/app/components/workflow/operator/tip-popup.tsx index a389d9e4c6..85e9a50a51 100644 --- a/web/app/components/workflow/operator/tip-popup.tsx +++ b/web/app/components/workflow/operator/tip-popup.tsx @@ -15,12 +15,12 @@ const TipPopup = ({ return ( - {title} +
+ {title} { - shortcuts && + shortcuts && }
} diff --git a/web/app/components/workflow/operator/zoom-in-out.tsx b/web/app/components/workflow/operator/zoom-in-out.tsx index 654097b430..6c4bed3751 100644 --- a/web/app/components/workflow/operator/zoom-in-out.tsx +++ b/web/app/components/workflow/operator/zoom-in-out.tsx @@ -18,10 +18,9 @@ import { useNodesSyncDraft, useWorkflowReadOnly, } from '../hooks' -import { - getKeyboardKeyNameBySystem, -} from '../utils' + import ShortcutsName from '../shortcuts-name' +import Divider from '../../base/divider' import TipPopup from './tip-popup' import cn from '@/utils/classnames' import { @@ -132,53 +131,54 @@ const ZoomInOut: FC = () => { >
{ e.stopPropagation() zoomOut() }} > - +
-
{parseFloat(`${zoom * 100}`).toFixed(0)}%
+
{parseFloat(`${zoom * 100}`).toFixed(0)}%
{ e.stopPropagation() zoomIn() }} > - +
-
+
{ ZOOM_IN_OUT_OPTIONS.map((options, i) => ( { i !== 0 && ( -
+ ) }
@@ -186,25 +186,27 @@ const ZoomInOut: FC = () => { options.map(option => (
handleZoom(option.key)} > - {option.text} - { - option.key === ZoomType.zoomToFit && ( - - ) - } - { - option.key === ZoomType.zoomTo50 && ( - - ) - } - { - option.key === ZoomType.zoomTo100 && ( - - ) - } + {option.text} +
+ { + option.key === ZoomType.zoomToFit && ( + + ) + } + { + option.key === ZoomType.zoomTo50 && ( + + ) + } + { + option.key === ZoomType.zoomTo100 && ( + + ) + } +
)) } diff --git a/web/app/components/workflow/panel-contextmenu.tsx b/web/app/components/workflow/panel-contextmenu.tsx index f01e3037a2..8ed0e10dca 100644 --- a/web/app/components/workflow/panel-contextmenu.tsx +++ b/web/app/components/workflow/panel-contextmenu.tsx @@ -5,6 +5,7 @@ import { } from 'react' import { useTranslation } from 'react-i18next' import { useClickAway } from 'ahooks' +import Divider from '../base/divider' import ShortcutsName from './shortcuts-name' import { useStore } from './store' import { @@ -41,7 +42,7 @@ const PanelContextmenu = () => { const renderTrigger = () => { return (
{t('workflow.common.addBlock')}
@@ -53,7 +54,7 @@ const PanelContextmenu = () => { return (
{ }} />
{ e.stopPropagation() handleAddNote() @@ -79,7 +80,7 @@ const PanelContextmenu = () => { {t('workflow.nodes.note.addNote')}
{ handleStartWorkflowRun() handlePaneContextmenuCancel() @@ -89,12 +90,12 @@ const PanelContextmenu = () => {
-
+
{ if (clipboardElements.length) { @@ -107,16 +108,16 @@ const PanelContextmenu = () => {
-
+
exportCheck()} > {t('app.export')}
setShowImportDSLModal(true)} > {t('workflow.common.importDSL')} diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 5d932a1ba2..ebd5e7a99d 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -27,6 +27,7 @@ import { getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' import type { FileEntity } from '@/app/components/base/file-uploader/types' +import type { NodeTracing } from '@/types/workflow' type GetAbortController = (abortController: AbortController) => void type SendCallback = { @@ -381,6 +382,28 @@ export const useChat = ( } })) }, + onNodeRetry: ({ data }) => { + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + if (responseItem.workflowProcess!.tracing[currentIndex].retryDetail) + responseItem.workflowProcess!.tracing[currentIndex].retryDetail?.push(data as NodeTracing) + else + responseItem.workflowProcess!.tracing[currentIndex].retryDetail = [data as NodeTracing] + + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) + }, onNodeFinished: ({ data }) => { if (data.iteration_id) return @@ -394,6 +417,9 @@ export const useChat = ( ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } : {}), + ...(responseItem.workflowProcess!.tracing[currentIndex]?.retryDetail + ? 
{ retryDetail: responseItem.workflowProcess!.tracing[currentIndex].retryDetail } + : {}), ...data, } as any handleUpdateChatList(produce(chatListRef.current, (draft) => { diff --git a/web/app/components/workflow/panel/workflow-preview.tsx b/web/app/components/workflow/panel/workflow-preview.tsx index 2139ebd338..210a95f1f8 100644 --- a/web/app/components/workflow/panel/workflow-preview.tsx +++ b/web/app/components/workflow/panel/workflow-preview.tsx @@ -25,6 +25,7 @@ import { import { SimpleBtn } from '../../app/text-generate/item' import Toast from '../../base/toast' import IterationResultPanel from '../run/iteration-result-panel' +import RetryResultPanel from '../run/retry-result-panel' import InputsPanel from './inputs-panel' import cn from '@/utils/classnames' import Loading from '@/app/components/base/loading' @@ -53,11 +54,16 @@ const WorkflowPreview = () => { }, [workflowRunningData]) const [iterationRunResult, setIterationRunResult] = useState([]) + const [retryRunResult, setRetryRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => { setIterDurationMap(iterationDurationMap) @@ -65,6 +71,11 @@ const WorkflowPreview = () => { doShowIterationDetail() }, [doShowIterationDetail]) + const handleRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail]) + if (isShowIterationDetail) { return (
{
)} - {currentTab === 'TRACING' && ( + {currentTab === 'TRACING' && !isShowRetryDetail && ( )} {currentTab === 'TRACING' && !workflowRunningData?.tracing?.length && ( @@ -213,7 +225,14 @@ const WorkflowPreview = () => {
)} - + { + currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + }
)} diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 2bf705f4ce..8b0319cabe 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -9,6 +9,7 @@ import OutputPanel from './output-panel' import ResultPanel from './result-panel' import TracingPanel from './tracing-panel' import IterationResultPanel from './iteration-result-panel' +import RetryResultPanel from './retry-result-panel' import cn from '@/utils/classnames' import { ToastContext } from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' @@ -77,11 +78,24 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const groupMap = nodeGroupMap.get(iterationNode.node_id)! - if (!groupMap.has(runId)) + if (!groupMap.has(runId)) { groupMap.set(runId, [item]) + } + else { + if (item.status === 'retry') { + const retryNode = groupMap.get(runId)!.find(node => node.node_id === item.node_id) - else - groupMap.get(runId)!.push(item) + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + } + else { + groupMap.get(runId)!.push(item) + } + } if (item.status === 'failed') { iterationNode.status = 'failed' @@ -93,10 +107,24 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { const { details } = iterationNode if (details) { - if (!details[index]) + if (!details[index]) { details[index] = [item] - else - details[index].push(item) + } + else { + if (item.status === 'retry') { + const retryNode = details[index].find(node => node.node_id === item.node_id) + + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + } + else { + details[index].push(item) + } + } } if (item.status === 'failed') { @@ -107,6 +135,18 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const processNonIterationNode = (item: NodeTracing) => { const { execution_metadata } = item if (!execution_metadata?.iteration_id) { + if (item.status === 'retry') { + const retryNode = result.find(node => node.node_id === item.node_id) + + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + + return + } result.push(item) return } @@ -181,10 +221,15 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const [iterationRunResult, setIterationRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) + const [retryRunResult, setRetryRunResult] = useState([]) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => { setIterationRunResult(detail) @@ -192,6 +237,11 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe setIterDurationMap(iterDurationMap) }, [doShowIterationDetail, setIterationRunResult, setIterDurationMap]) + const handleShowRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail, setRetryRunResult]) + if (isShowIterationDetail) { return (
@@ -261,13 +311,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe exceptionCounts={runDetail.exceptions_count} /> )} - {!loading && currentTab === 'TRACING' && ( + {!loading && currentTab === 'TRACING' && !isShowRetryDetail && ( )} + { + !loading && currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + }
) diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index b13eadec99..b809e1e669 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -11,6 +11,7 @@ import { import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' import { NodeRunningStatus } from '../types' import TracingPanel from './tracing-panel' +import RetryResultPanel from './retry-result-panel' import { Iteration } from '@/app/components/base/icons/src/vender/workflow' import cn from '@/utils/classnames' import type { IterationDurationMap, NodeTracing } from '@/types/workflow' @@ -41,8 +42,8 @@ const IterationResultPanel: FC = ({ })) }, []) const countIterDuration = (iteration: NodeTracing[], iterDurationMap: IterationDurationMap): string => { - const IterRunIndex = iteration[0].execution_metadata.iteration_index as number - const iterRunId = iteration[0].execution_metadata.parallel_mode_run_id + const IterRunIndex = iteration[0]?.execution_metadata?.iteration_index as number + const iterRunId = iteration[0]?.execution_metadata?.parallel_mode_run_id const iterItem = iterDurationMap[iterRunId || IterRunIndex] const duration = iterItem return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s` @@ -74,6 +75,10 @@ const IterationResultPanel: FC = ({ ) } + const [retryRunResult, setRetryRunResult] = useState | undefined>() + const handleRetryDetail = (v: number, detail?: NodeTracing[]) => { + setRetryRunResult({ ...retryRunResult, [v]: detail }) + } const main = ( <> @@ -116,15 +121,28 @@ const IterationResultPanel: FC = ({ {expandedIterations[index] &&
} -
- -
+ { + !retryRunResult?.[index] && ( +
+ handleRetryDetail(index, v)} + /> +
+ ) + } + { + retryRunResult?.[index] && ( + handleRetryDetail(index, undefined)} + /> + ) + }
))}
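The iteration-result-panel.tsx hunk above keys retry details by iteration index and swaps between TracingPanel and RetryResultPanel depending on whether an entry is set; passing undefined toggles that iteration back to the tracing view. A self-contained sketch of the state shape (the hook name is illustrative; generic T stands in for NodeTracing):

import { useCallback, useState } from 'react'

export function useRetryByIteration<T>() {
  // one optional retry-detail list per iteration index
  const [retryRunResult, setRetryRunResult] = useState<Record<number, T[] | undefined>>({})

  const handleRetryDetail = useCallback((index: number, detail?: T[]) => {
    // undefined clears the entry, switching that row back to plain tracing
    setRetryRunResult(prev => ({ ...prev, [index]: detail }))
  }, [])

  return { retryRunResult, handleRetryDetail }
}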
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index d1a02ecfe0..d2da319a02 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -8,6 +8,7 @@ import { RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, + RiRestartFill, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' @@ -20,6 +21,7 @@ import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import type { IterationDurationMap, NodeTracing } from '@/types/workflow' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import { hasRetryNode } from '@/app/components/workflow/utils' type Props = { className?: string @@ -28,8 +30,10 @@ type Props = { hideInfo?: boolean hideProcessDetail?: boolean onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void notShowIterationNav?: boolean justShowIterationNavArrow?: boolean + justShowRetryNavArrow?: boolean } const NodePanel: FC = ({ @@ -39,6 +43,7 @@ const NodePanel: FC = ({ hideInfo = false, hideProcessDetail, onShowIterationDetail, + onShowRetryDetail, notShowIterationNav, justShowIterationNavArrow, }) => { @@ -88,11 +93,17 @@ const NodePanel: FC = ({ }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration + const isRetryNode = hasRetryNode(nodeInfo.node_type) && nodeInfo.retryDetail const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {}) } + const handleOnShowRetryDetail = (e: React.MouseEvent) => { + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onShowRetryDetail?.(nodeInfo.retryDetail || []) + } return (
@@ -169,6 +180,19 @@ const NodePanel: FC = ({
)} + {isRetryNode && ( + + )}
{(nodeInfo.status === 'stopped') && ( @@ -192,6 +216,11 @@ const NodePanel: FC = ({ {nodeInfo.error} )} + {nodeInfo.status === 'retry' && ( + + {nodeInfo.error} + + )}
{nodeInfo.inputs && (
diff --git a/web/app/components/workflow/run/result-panel.tsx b/web/app/components/workflow/run/result-panel.tsx index a688693e4f..bbe740ad48 100644 --- a/web/app/components/workflow/run/result-panel.tsx +++ b/web/app/components/workflow/run/result-panel.tsx @@ -1,11 +1,17 @@ 'use client' import type { FC } from 'react' import { useTranslation } from 'react-i18next' +import { + RiArrowRightSLine, + RiRestartFill, +} from '@remixicon/react' import StatusPanel from './status' import MetaData from './meta' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import type { NodeTracing } from '@/types/workflow' +import Button from '@/app/components/base/button' type ResultPanelProps = { inputs?: string @@ -22,6 +28,8 @@ type ResultPanelProps = { showSteps?: boolean exceptionCounts?: number execution_metadata?: any + retry_events?: NodeTracing[] + onShowRetryDetail?: (retries: NodeTracing[]) => void } const ResultPanel: FC = ({ @@ -38,8 +46,11 @@ const ResultPanel: FC = ({ showSteps, exceptionCounts, execution_metadata, + retry_events, + onShowRetryDetail, }) => { const { t } = useTranslation() + return (
@@ -51,6 +62,23 @@ const ResultPanel: FC = ({ exceptionCounts={exceptionCounts} />
+ { + retry_events?.length && onShowRetryDetail && ( +
+ +
+ ) + }
void +} + +const RetryResultPanel: FC = ({ + list, + onBack, +}) => { + const { t } = useTranslation() + + return ( +
+
{ + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onBack() + }} + > + + {t('workflow.singleRun.back')} +
+ ({ + ...item, + title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`, + }))} + className='bg-background-section-burn' + /> +
+ ) +} +export default memo(RetryResultPanel) diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index 57b3a5cf5f..ad78971895 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -21,6 +21,7 @@ import type { IterationDurationMap, NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void className?: string hideNodeInfo?: boolean hideNodeProcessDetail?: boolean @@ -160,6 +161,7 @@ function buildLogTree(nodes: NodeTracing[], t: (key: string) => string): Tracing const TracingPanel: FC = ({ list, onShowIterationDetail, + onShowRetryDetail, className, hideNodeInfo = false, hideNodeProcessDetail = false, @@ -251,7 +253,9 @@ const TracingPanel: FC = ({ diff --git a/web/app/components/workflow/shortcuts-name.tsx b/web/app/components/workflow/shortcuts-name.tsx index 129753c198..cfb5c33daf 100644 --- a/web/app/components/workflow/shortcuts-name.tsx +++ b/web/app/components/workflow/shortcuts-name.tsx @@ -12,14 +12,14 @@ const ShortcutsName = ({ }: ShortcutsNameProps) => { return (
{ keys.map(key => (
{getKeyboardKeyNameBySystem(key)}
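A recurring pattern in the run/index.tsx and debug-and-preview/hooks.ts hunks earlier in this diff: tracing items with status 'retry' are never appended as rows of their own; they are folded into the retryDetail array of the node entry they belong to, which RetryResultPanel then renders. A simplified sketch of that grouping, ignoring the parallel_id matching the real handlers also perform (NodeTraceLite is a pared-down stand-in for NodeTracing):

type NodeTraceLite = {
  node_id: string
  status: string
  retryDetail?: NodeTraceLite[]
}

// Fold 'retry' events into their parent node's entry so the trace shows
// one row per node, with individual attempts kept under retryDetail.
function foldRetryEvents(items: NodeTraceLite[]): NodeTraceLite[] {
  const result: NodeTraceLite[] = []
  for (const item of items) {
    if (item.status === 'retry') {
      const parent = result.find(node => node.node_id === item.node_id)
      if (parent)
        (parent.retryDetail ??= []).push(item)
      continue
    }
    result.push(item)
  }
  return result
}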
diff --git a/web/app/components/workflow/style.css b/web/app/components/workflow/style.css index ca1d24a52e..253d6b7dd0 100644 --- a/web/app/components/workflow/style.css +++ b/web/app/components/workflow/style.css @@ -19,4 +19,6 @@ #workflow-container .react-flow__node-custom-note { z-index: -1000 !important; -} \ No newline at end of file +} + +#workflow-container .react-flow {} \ No newline at end of file diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index c40ea0de55..6d0fabd90e 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -13,6 +13,7 @@ import type { DefaultValueForm, ErrorHandleTypeEnum, } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import type { WorkflowRetryConfig } from '@/app/components/workflow/nodes/_base/components/retry/types' export enum BlockEnum { Start = 'start', @@ -68,6 +69,7 @@ export type CommonNodeType = { _iterationIndex?: number _inParallelHovering?: boolean _waitingRun?: boolean + _retryIndex?: number isInIteration?: boolean iteration_id?: string selected?: boolean @@ -77,6 +79,7 @@ export type CommonNodeType = { width?: number height?: number error_strategy?: ErrorHandleTypeEnum + retry_config?: WorkflowRetryConfig default_value?: DefaultValueForm[] } & T & Partial> @@ -293,6 +296,7 @@ export enum NodeRunningStatus { Succeeded = 'succeeded', Failed = 'failed', Exception = 'exception', + Retry = 'retry', } export type OnNodeAdd = ( diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index f5c112d6e8..4c61267e4c 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -26,6 +26,8 @@ import { } from './types' import { CUSTOM_NODE, + DEFAULT_RETRY_INTERVAL, + DEFAULT_RETRY_MAX, ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, @@ -279,6 +281,14 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated } + if (node.data.type === BlockEnum.HttpRequest && !node.data.retry_config) { + node.data.retry_config = { + retry_enabled: true, + max_retries: DEFAULT_RETRY_MAX, + retry_interval: DEFAULT_RETRY_INTERVAL, + } + } + return node }) } @@ -549,6 +559,7 @@ export const isMac = () => { const specialKeysNameMap: Record = { ctrl: '⌘', alt: '⌥', + shift: '⇧', } export const getKeyboardKeyNameBySystem = (key: string) => { @@ -796,3 +807,7 @@ export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => { return false } + +export const hasRetryNode = (nodeType?: BlockEnum) => { + return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code +} diff --git a/web/app/layout.tsx b/web/app/layout.tsx index 48e35c50e0..9c4f8ba16d 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -34,7 +34,7 @@ const LocaleLayout = ({ diff --git a/web/app/signin/normalForm.tsx b/web/app/signin/normalForm.tsx index 783d8ac507..1911fa35c6 100644 --- a/web/app/signin/normalForm.tsx +++ b/web/app/signin/normalForm.tsx @@ -163,7 +163,7 @@ const NormalForm = () => {
} } {systemFeatures.enable_email_password_login && authType === 'password' && <> - + {systemFeatures.enable_email_code_login &&
{ updateAuthType('code') }}> {t('login.useVerificationCode')}
} diff --git a/web/app/styles/globals.css b/web/app/styles/globals.css index 18bf39345d..f2aadc5820 100644 --- a/web/app/styles/globals.css +++ b/web/app/styles/globals.css @@ -7,6 +7,14 @@ @import "../../themes/manual-light.css"; @import "../../themes/manual-dark.css"; +html { + color-scheme: light; +} + +html[data-theme='dark'] { + color-scheme: dark; +} + html[data-changing-theme] * { transition: none !important; } diff --git a/web/context/app-context.tsx b/web/context/app-context.tsx index 369fe5af19..7addfb83d4 100644 --- a/web/context/app-context.tsx +++ b/web/context/app-context.tsx @@ -31,6 +31,7 @@ export type AppContextValue = { pageContainerRef: React.RefObject langeniusVersionInfo: LangGeniusVersionResponse useSelector: typeof useSelector + isLoadingCurrentWorkspace: boolean } const initialLangeniusVersionInfo = { @@ -77,6 +78,7 @@ const AppContext = createContext({ pageContainerRef: createRef(), langeniusVersionInfo: initialLangeniusVersionInfo, useSelector, + isLoadingCurrentWorkspace: false, }) export function useSelector(selector: (value: AppContextValue) => T): T { @@ -92,7 +94,7 @@ export const AppContextProvider: FC = ({ children }) => const { data: appList, mutate: mutateApps } = useSWR({ url: '/apps', params: { page: 1, limit: 30, name: '' } }, fetchAppList) const { data: userProfileResponse, mutate: mutateUserProfile } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile) - const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace) + const { data: currentWorkspaceResponse, mutate: mutateCurrentWorkspace, isLoading: isLoadingCurrentWorkspace } = useSWR({ url: '/workspaces/current', params: {} }, fetchCurrentWorkspace) const { data: systemFeatures } = useSWR({ url: '/console/system-features' }, getSystemFeatures, { fallbackData: defaultSystemFeatures, @@ -157,6 +159,7 @@ export const AppContextProvider: FC = ({ children }) => isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, mutateCurrentWorkspace, + isLoadingCurrentWorkspace, }}>
{globalThis.document?.body?.getAttribute('data-public-maintenance-notice') && } diff --git a/web/i18n/de-DE/common.ts index 1d7ca955fa..f438b4f018 100644 --- a/web/i18n/de-DE/common.ts +++ b/web/i18n/de-DE/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Kann Apps erstellen & bearbeiten', inviteTeamMember: 'Teammitglied hinzufügen', inviteTeamMemberTip: 'Sie können direkt nach der Anmeldung auf Ihre Teamdaten zugreifen.', + emailNotSetup: 'E-Mail-Server ist nicht eingerichtet, daher können keine Einladungs-E-Mails versendet werden. Bitte informieren Sie die Benutzer über den Einladungslink, der nach der Einladung ausgestellt wird.', email: 'E-Mail', emailInvalid: 'Ungültiges E-Mail-Format', emailPlaceholder: 'Bitte E-Mails eingeben', diff --git a/web/i18n/de-DE/tools.ts index 3be01b8350..2448b3ed8f 100644 --- a/web/i18n/de-DE/tools.ts +++ b/web/i18n/de-DE/tools.ts @@ -144,6 +144,8 @@ const translation = { emptyTitle: 'Kein Workflow-Tool verfügbar', type: 'Art', emptyTip: 'Gehen Sie zu "Workflow -> Als Tool veröffentlichen"', + emptyTitleCustom: 'Kein benutzerdefiniertes Tool verfügbar', + emptyTipCustom: 'Erstellen eines benutzerdefinierten Werkzeugs', }, toolNameUsageTip: 'Name des Tool-Aufrufs für die Argumentation und Aufforderung des Agenten', customToolTip: 'Erfahren Sie mehr über benutzerdefinierte Dify-Tools', diff --git a/web/i18n/de-DE/workflow.ts index 8888e23739..38686f8c1d 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Fehlerbehandlung', tip: 'Ausnahmebehandlungsstrategie, die ausgelöst wird, wenn ein Knoten auf eine Ausnahme stößt.', }, + retry: { + retry: 'Wiederholen', + retryOnFailure: 'Wiederholen bei Fehler', + maxRetries: 'Max. Wiederholungen', + retryInterval: 'Wiederholungsintervall', + retryTimes: 'Wiederholen Sie {{times}} mal bei einem Fehler', + retrying: 'Wiederholung...', + retrySuccessful: 'Wiederholen erfolgreich', + retryFailed: 'Wiederholung fehlgeschlagen', + retryFailedTimes: '{{times}} fehlgeschlagene Wiederholungen', + times: 'mal', + ms: 'ms', + retries: '{{num}} Wiederholungen', + }, }, start: { required: 'erforderlich', diff --git a/web/i18n/en-US/app.ts index 546bb08873..2d3ad99cae 100644 --- a/web/i18n/en-US/app.ts +++ b/web/i18n/en-US/app.ts @@ -125,7 +125,7 @@ const translation = { switchStart: 'Start switch', openInExplore: 'Open in Explore', typeSelector: { - all: 'ALL Types', + all: 'All Types', chatbot: 'Chatbot', agent: 'Agent', workflow: 'Workflow', diff --git a/web/i18n/en-US/common.ts index ce341a9148..ea0e4a88aa 100644 --- a/web/i18n/en-US/common.ts +++ b/web/i18n/en-US/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Only can manage the knowledge base', inviteTeamMember: 'Add team member', inviteTeamMemberTip: 'They can access your team data directly after signing in.', + emailNotSetup: 'Email server is not set up, so invitation emails cannot be sent. 
Please notify users of the invitation link that will be issued after invitation instead.', email: 'Email', emailInvalid: 'Invalid Email Format', emailPlaceholder: 'Please input emails', diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts index f96ae8144e..b1f278f9ce 100644 --- a/web/i18n/en-US/tools.ts +++ b/web/i18n/en-US/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Manage in Tools', emptyTitle: 'No workflow tool available', emptyTip: 'Go to "Workflow -> Publish as Tool"', + emptyTitleCustom: 'No custom tool available', + emptyTipCustom: 'Create a custom tool', }, createTool: { title: 'Create Custom Tool', diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index e2a2fdb59d..fab25fa509 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: 'There are {{num}} nodes in the process running abnormally, please go to tracing to check the logs.', }, }, + retry: { + retry: 'Retry', + retryOnFailure: 'retry on failure', + maxRetries: 'max retries', + retryInterval: 'retry interval', + retryTimes: 'Retry {{times}} times on failure', + retrying: 'Retrying...', + retrySuccessful: 'Retry successful', + retryFailed: 'Retry failed', + retryFailedTimes: '{{times}} retries failed', + times: 'times', + ms: 'ms', + retries: '{{num}} Retries', + }, }, start: { required: 'required', diff --git a/web/i18n/es-ES/common.ts b/web/i18n/es-ES/common.ts index cc9fb47329..2540632758 100644 --- a/web/i18n/es-ES/common.ts +++ b/web/i18n/es-ES/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Solo puede administrar la base de conocimiento', inviteTeamMember: 'Agregar miembro del equipo', inviteTeamMemberTip: 'Pueden acceder a tus datos del equipo directamente después de iniciar sesión.', + emailNotSetup: 'El servidor de correo no está configurado, por lo que no se pueden enviar correos de invitación. 
En su lugar, notifique a los usuarios el enlace de invitación que se emitirá después de la invitación.', email: 'Correo electrónico', emailInvalid: 'Formato de correo electrónico inválido', emailPlaceholder: 'Por favor ingresa correos electrónicos', diff --git a/web/i18n/es-ES/tools.ts index 546591f1aa..08c9f2026d 100644 --- a/web/i18n/es-ES/tools.ts +++ b/web/i18n/es-ES/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Administrar en Herramientas', emptyTitle: 'No hay herramientas de flujo de trabajo disponibles', emptyTip: 'Ir a "Flujo de Trabajo -> Publicar como Herramienta"', + emptyTitleCustom: 'No hay herramienta personalizada disponible', + emptyTipCustom: 'Crear una herramienta personalizada', }, createTool: { title: 'Crear Herramienta Personalizada', diff --git a/web/i18n/es-ES/workflow.ts index c49c611da8..d112ad97b6 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Manejo de errores', tip: 'Estrategia de control de excepciones, que se desencadena cuando un nodo encuentra una excepción.', }, + retry: { + retryOnFailure: 'Volver a intentarlo en caso de error', + maxRetries: 'Número máximo de reintentos', + retryInterval: 'Intervalo de reintento', + retryTimes: 'Reintentar {{times}} veces en caso de error', + retrying: 'Reintentando...', + retrySuccessful: 'Volver a intentarlo correctamente', + retryFailed: 'Error en el reintento', + retryFailedTimes: '{{times}} reintentos fallidos', + times: 'veces', + ms: 'ms', + retries: '{{num}} Reintentos', + retry: 'Reintentar', + }, }, start: { required: 'requerido', diff --git a/web/i18n/fa-IR/common.ts index 2da6cdee8b..deab852ddb 100644 --- a/web/i18n/fa-IR/common.ts +++ b/web/i18n/fa-IR/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'فقط می‌تواند پایگاه دانش را مدیریت کند', inviteTeamMember: 'افزودن عضو تیم', inviteTeamMemberTip: 'آنها می‌توانند پس از ورود به سیستم، مستقیماً به داده‌های تیم شما دسترسی پیدا کنند.', + emailNotSetup: 'سرور ایمیل راه‌اندازی نشده است، بنابراین ایمیل‌های دعوت نمی‌توانند ارسال شوند. 
لطفاً کاربران را از لینک دعوت که پس از دعوت صادر خواهد شد مطلع کنید。', email: 'ایمیل', emailInvalid: 'فرمت ایمیل نامعتبر است', emailPlaceholder: 'لطفاً ایمیل‌ها را وارد کنید', diff --git a/web/i18n/fa-IR/tools.ts b/web/i18n/fa-IR/tools.ts index 002f55d1d4..60a89d0f32 100644 --- a/web/i18n/fa-IR/tools.ts +++ b/web/i18n/fa-IR/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'مدیریت در ابزارها', emptyTitle: 'هیچ ابزار جریان کاری در دسترس نیست', emptyTip: 'به "جریان کاری -> انتشار به عنوان ابزار" بروید', + emptyTipCustom: 'ایجاد یک ابزار سفارشی', + emptyTitleCustom: 'هیچ ابزار سفارشی در دسترس نیست', }, createTool: { title: 'ایجاد ابزار سفارشی', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index c29f911556..37cba2f16b 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'مدیریت خطا', tip: 'استراتژی مدیریت استثنا، زمانی که یک گره با یک استثنا مواجه می شود، فعال می شود.', }, + retry: { + times: 'بار', + retryInterval: 'فاصله تلاش مجدد', + retryOnFailure: 'در مورد شکست دوباره امتحان کنید', + ms: 'خانم', + retry: 'دوباره', + retries: '{{عدد}} تلاش های مجدد', + maxRetries: 'حداکثر تلاش مجدد', + retrying: 'تلاش مجدد...', + retryFailed: 'تلاش مجدد ناموفق بود', + retryTimes: '{{times}} بار در صورت شکست دوباره امتحان کنید', + retrySuccessful: 'امتحان مجدد با موفقیت انجام دهید', + retryFailedTimes: '{{بار}} تلاش های مجدد ناموفق بود', + }, }, start: { required: 'الزامی', diff --git a/web/i18n/fr-FR/common.ts b/web/i18n/fr-FR/common.ts index 326572916c..25142c11cc 100644 --- a/web/i18n/fr-FR/common.ts +++ b/web/i18n/fr-FR/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Peut construire des applications, mais ne peut pas gérer les paramètres de l\'équipe', inviteTeamMember: 'Ajouter un membre de l\'équipe', inviteTeamMemberTip: 'Ils peuvent accéder directement à vos données d\'équipe après s\'être connectés.', + emailNotSetup: 'Le serveur de messagerie n\'est pas configuré, les e-mails d\'invitation ne peuvent donc pas être envoyés. 
Veuillez informer les utilisateurs du lien d\'invitation qui sera émis après l\'invitation.', email: 'Courrier électronique', emailInvalid: 'Format de courriel invalide', emailPlaceholder: 'Veuillez entrer des emails', diff --git a/web/i18n/fr-FR/tools.ts b/web/i18n/fr-FR/tools.ts index 34c71e7764..5a7e47906f 100644 --- a/web/i18n/fr-FR/tools.ts +++ b/web/i18n/fr-FR/tools.ts @@ -144,6 +144,8 @@ const translation = { category: 'catégorie', manageInTools: 'Gérer dans Outils', emptyTip: 'Allez dans « Flux de travail -> Publier en tant qu’outil »', + emptyTitleCustom: 'Aucun outil personnalisé disponible', + emptyTipCustom: 'Créer un outil personnalisé', }, openInStudio: 'Ouvrir dans Studio', customToolTip: 'En savoir plus sur les outils personnalisés Dify', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index a2b2406113..e7d2802cb4 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestion des erreurs', tip: 'Stratégie de gestion des exceptions, déclenchée lorsqu’un nœud rencontre une exception.', }, + retry: { + retry: 'Réessayer', + retryOnFailure: 'Réessai en cas d’échec', + maxRetries: 'Nombre maximal de tentatives', + retryInterval: 'intervalle de nouvelle tentative', + retryTimes: 'Réessayez {{times}} fois en cas d’échec', + retrying: 'Réessayer...', + retrySuccessful: 'Réessai réussi', + retryFailed: 'Échec de la nouvelle tentative', + retryFailedTimes: '{{times}} les tentatives ont échoué', + times: 'fois', + ms: 'ms', + retries: '{{num}} Tentatives', + }, }, start: { required: 'requis', diff --git a/web/i18n/hi-IN/common.ts b/web/i18n/hi-IN/common.ts index a2c178cb18..aabcfc86e6 100644 --- a/web/i18n/hi-IN/common.ts +++ b/web/i18n/hi-IN/common.ts @@ -204,6 +204,7 @@ const translation = { inviteTeamMember: 'टीम सदस्य जोड़ें', inviteTeamMemberTip: 'वे साइन इन करने के बाद सीधे आपकी टीम डेटा तक पहुंच सकते हैं।', + emailNotSetup: 'ईमेल सर्वर सेट नहीं है, इसलिए आमंत्रण ईमेल नहीं भेजे जा सकते। कृपया उपयोगकर्ताओं को आमंत्रण के बाद जारी किए जाने वाले आमंत्रण लिंक के बारे में सूचित करें。', email: 'ईमेल', emailInvalid: 'अवैध ईमेल प्रारूप', emailPlaceholder: 'कृपया ईमेल दर्ज करें', diff --git a/web/i18n/hi-IN/tools.ts b/web/i18n/hi-IN/tools.ts index 6b0cccebad..2060682931 100644 --- a/web/i18n/hi-IN/tools.ts +++ b/web/i18n/hi-IN/tools.ts @@ -32,6 +32,8 @@ const translation = { manageInTools: 'उपकरणों में प्रबंधित करें', emptyTitle: 'कोई कार्यप्रवाह उपकरण उपलब्ध नहीं', emptyTip: 'कार्यप्रवाह -> उपकरण के रूप में प्रकाशित पर जाएं', + emptyTipCustom: 'एक कस्टम टूल बनाएं', + emptyTitleCustom: 'कोई कस्टम टूल उपलब्ध नहीं है', }, createTool: { title: 'कस्टम उपकरण बनाएं', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 47589078ce..619abee128 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -334,6 +334,20 @@ const translation = { title: 'त्रुटि हैंडलिंग', tip: 'अपवाद हैंडलिंग रणनीति, ट्रिगर जब एक नोड एक अपवाद का सामना करता है।', }, + retry: { + times: 'गुणा', + ms: 'सुश्री', + retryInterval: 'अंतराल का पुनः प्रयास करें', + retrying: 'पुनर्प्रयास।।।', + retryFailed: 'पुनः प्रयास विफल रहा', + retryFailedTimes: '{{times}} पुनः प्रयास विफल रहे', + retryTimes: 'विफलता पर {{times}} बार पुनः प्रयास करें', + retries: '{{num}} पुनर्प्रयास', + maxRetries: 'अधिकतम पुनः प्रयास करता है', + retrySuccessful: 'पुनः प्रयास सफल', + retry: 'पुनर्प्रयास', + retryOnFailure: 'विफलता पर पुनः प्रयास करें', + }, }, start: { required: 'आवश्यक', diff --git a/web/i18n/it-IT/common.ts 
b/web/i18n/it-IT/common.ts index 35a01d7114..4cee6dec50 100644 --- a/web/i18n/it-IT/common.ts +++ b/web/i18n/it-IT/common.ts @@ -208,6 +208,7 @@ const translation = { inviteTeamMember: 'Aggiungi membro del team', inviteTeamMemberTip: 'Potranno accedere ai dati del tuo team direttamente dopo aver effettuato l\'accesso.', + emailNotSetup: 'Il server email non è configurato, quindi non è possibile inviare email di invito. Si prega di notificare agli utenti il link di invito che verrà emesso dopo l\'invito.', email: 'Email', emailInvalid: 'Formato Email non valido', emailPlaceholder: 'Per favore inserisci le email', diff --git a/web/i18n/it-IT/tools.ts b/web/i18n/it-IT/tools.ts index 00e7cad58c..f9512fb20d 100644 --- a/web/i18n/it-IT/tools.ts +++ b/web/i18n/it-IT/tools.ts @@ -32,6 +32,8 @@ const translation = { manageInTools: 'Gestisci in Strumenti', emptyTitle: 'Nessun strumento di flusso di lavoro disponibile', emptyTip: 'Vai a `Flusso di lavoro -> Pubblica come Strumento`', + emptyTitleCustom: 'Nessun attrezzo personalizzato disponibile', + emptyTipCustom: 'Creare uno strumento personalizzato', }, createTool: { title: 'Crea Strumento Personalizzato', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index e760074e6a..f4390580d5 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -337,6 +337,20 @@ const translation = { title: 'Gestione degli errori', tip: 'Strategia di gestione delle eccezioni, attivata quando un nodo rileva un\'eccezione.', }, + retry: { + retry: 'Ripetere', + retryOnFailure: 'Riprova in caso di errore', + maxRetries: 'Numero massimo di tentativi', + retryInterval: 'Intervallo tentativi', + retryTimes: 'Riprova {{times}} volte in caso di errore', + retrying: 'Riprovare...', + retryFailedTimes: '{{times}} tentativi falliti', + times: 'tempi', + retries: '{{num}} Tentativi', + retrySuccessful: 'Riprova riuscito', + retryFailed: 'Nuovo tentativo non riuscito', + ms: 'ms', + }, }, start: { required: 'richiesto', diff --git a/web/i18n/ja-JP/common.ts b/web/i18n/ja-JP/common.ts index fa3fb223f4..9c23cb6f16 100644 --- a/web/i18n/ja-JP/common.ts +++ b/web/i18n/ja-JP/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'ナレッジベースのみを管理できる', inviteTeamMember: 'チームメンバーを招待する', inviteTeamMemberTip: '彼らはサインイン後、直接あなた様のチームデータにアクセスできます。', + emailNotSetup: 'メールサーバーがセットアップされていないので、招待メールを送信することはできません。代わりに招待後に発行される招待リンクをユーザーに通知してください。', email: 'メール', emailInvalid: '無効なメール形式', emailPlaceholder: 'メールを入力してください', diff --git a/web/i18n/ja-JP/tools.ts b/web/i18n/ja-JP/tools.ts index 12d3634715..f52f101f52 100644 --- a/web/i18n/ja-JP/tools.ts +++ b/web/i18n/ja-JP/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'ツールリストに移動して管理する', emptyTitle: '利用可能なワークフローツールはありません', emptyTip: '追加するには、「ワークフロー -> ツールとして公開 」に移動する', + emptyTitleCustom: 'カスタムツールはありません', + emptyTipCustom: 'カスタムツールの作成', }, createTool: { title: 'カスタムツールを作成する', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index ebb613d31f..1aa764a19f 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -299,22 +299,22 @@ const translation = { }, errorHandle: { none: { - title: '何一つ', + title: '処理なし', desc: '例外が発生して処理されない場合、ノードは実行を停止します', }, defaultValue: { - title: '既定値', - desc: 'エラーが発生した場合は、静的な出力コンテンツを指定します。', - tip: 'エラーの場合は、以下の値を返します。', + title: 'デフォルト値', + desc: '例外が発生した場合は、デフォルトの出力コンテンツを指定します。', + tip: '例外が発生した場合は、以下の値を返します。', inLog: 'ノード例外、デフォルト値に従って出力します。', output: '出力デフォルト値', }, failBranch: { - title: '失敗ブランチ', - customize: 
'キャンバスに移動して、失敗ブランチのロジックをカスタマイズします。', - inLog: 'ノード例外は、失敗したブランチを自動的に実行します。ノード出力は、エラータイプとエラーメッセージを返し、それらをダウンストリームに渡します。', - desc: 'エラーが発生した場合は、例外ブランチを実行します', - customizeTip: 'fail ブランチがアクティブになっても、ノードによってスローされた例外はプロセスを終了させません。代わりに、事前定義された fail ブランチが自動的に実行されるため、エラー メッセージ、レポート、修正、またはスキップ アクションを柔軟に提供できます。', + title: 'エラーブランチ', + customize: 'キャンバスに移動して、エラーブランチのロジックをカスタマイズします。', + inLog: '例外が発生した場合は、エラーしたブランチを自動的に実行します。ノード出力は、エラータイプとエラーメッセージを返し、それらをダウンストリームに渡します。', + desc: '例外が発生した場合は、エラーブランチを実行します', + customizeTip: 'エラーブランチがアクティブになっても、ノードによってスローされた例外はプロセスを終了させません。代わりに、事前定義された エラーブランチが自動的に実行されるため、エラーメッセージ、レポート、修正アクション、またはスキップアクションを柔軟に提供できます。', }, partialSucceeded: { tip: 'プロセスに{{num}}ノードが異常に動作していますので、トレースに移動してログを確認してください。', @@ -322,6 +322,20 @@ const translation = { title: 'エラー処理', tip: 'ノードが例外を検出したときにトリガーされる例外処理戦略。', }, + retry: { + retry: 'リトライ', + retryOnFailure: '失敗時の再試行', + maxRetries: '最大再試行回数', + retryInterval: '再試行間隔', + retrying: '再試行。。。', + retryFailed: '再試行に失敗しました', + times: '倍', + ms: 'さん', + retryTimes: '失敗時に{{times}}回再試行', + retrySuccessful: '再試行に成功しました', + retries: '{{num}} 回の再試行', + retryFailedTimes: '{{times}}回のリトライが失敗しました', + }, }, start: { required: '必須', diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts index ce860e000e..a599aa9bd1 100644 --- a/web/i18n/ko-KR/common.ts +++ b/web/i18n/ko-KR/common.ts @@ -187,6 +187,7 @@ const translation = { editorTip: '앱 빌드만 가능하고 팀 설정 관리 불가능', inviteTeamMember: '팀 멤버 초대', inviteTeamMemberTip: '로그인 후에 바로 팀 데이터에 액세스할 수 있습니다.', + emailNotSetup: '이메일 서버가 설정되지 않아 초대 이메일을 보낼 수 없습니다. 대신 초대 후 발급되는 초대 링크를 사용자에게 알려주세요.', email: '이메일', emailInvalid: '유효하지 않은 이메일 형식', emailPlaceholder: '이메일 입력', diff --git a/web/i18n/ko-KR/tools.ts b/web/i18n/ko-KR/tools.ts index c896a17a4f..0b9f451784 100644 --- a/web/i18n/ko-KR/tools.ts +++ b/web/i18n/ko-KR/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: '도구에서 관리', emptyTitle: '사용 가능한 워크플로우 도구 없음', emptyTip: '"워크플로우 -> 도구로 등록하기"로 이동', + emptyTipCustom: '사용자 지정 도구 만들기', + emptyTitleCustom: '사용 가능한 사용자 지정 도구가 없습니다.', }, createTool: { title: '커스텀 도구 만들기', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index cc2c1b1a28..4a4d2f9193 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '오류 처리', tip: '노드에 예외가 발생할 때 트리거되는 예외 처리 전략입니다.', }, + retry: { + retry: '재시도', + retryOnFailure: '실패 시 재시도', + maxRetries: '최대 재시도 횟수', + retryInterval: '재시도 간격', + retryTimes: '실패 시 {{times}}번 재시도', + retrying: '재시도...', + retrySuccessful: '재시도 성공', + retryFailed: '재시도 실패', + retryFailedTimes: '{{times}} 재시도 실패', + times: '배', + ms: '미에스', + retries: '{{숫자}} 재시도', + }, }, start: { required: '필수', diff --git a/web/i18n/pl-PL/common.ts b/web/i18n/pl-PL/common.ts index baaf5292c3..69441dbab3 100644 --- a/web/i18n/pl-PL/common.ts +++ b/web/i18n/pl-PL/common.ts @@ -198,6 +198,7 @@ const translation = { inviteTeamMember: 'Dodaj członka zespołu', inviteTeamMemberTip: 'Mogą uzyskać bezpośredni dostęp do danych Twojego zespołu po zalogowaniu.', + emailNotSetup: 'Serwer poczty nie jest skonfigurowany, więc nie można wysyłać zaproszeń e-mail. 
Proszę powiadomić użytkowników o linku do zaproszenia, który zostanie wydany po zaproszeniu.', email: 'Email', emailInvalid: 'Nieprawidłowy format e-maila', emailPlaceholder: 'Proszę podać adresy e-mail', diff --git a/web/i18n/pl-PL/tools.ts b/web/i18n/pl-PL/tools.ts index f34825b049..768883522e 100644 --- a/web/i18n/pl-PL/tools.ts +++ b/web/i18n/pl-PL/tools.ts @@ -148,6 +148,8 @@ const translation = { add: 'dodawać', emptyTitle: 'Brak dostępnego narzędzia do przepływu pracy', emptyTip: 'Przejdź do "Przepływ pracy -> Opublikuj jako narzędzie"', + emptyTitleCustom: 'Brak dostępnego narzędzia niestandardowego', + emptyTipCustom: 'Tworzenie narzędzia niestandardowego', }, openInStudio: 'Otwieranie w Studio', customToolTip: 'Dowiedz się więcej o niestandardowych narzędziach Dify', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 2db6cf2bfb..13784df603 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Strategia obsługi wyjątków, wyzwalana, gdy węzeł napotka wyjątek.', title: 'Obsługa błędów', }, + retry: { + retry: 'Ponów próbę', + maxRetries: 'Maksymalna liczba ponownych prób', + retryInterval: 'Interwał ponawiania prób', + retryTimes: 'Ponów próbę {{times}} razy w przypadku niepowodzenia', + retrying: 'Ponawianie...', + retrySuccessful: 'Ponawianie próby powiodło się', + retryFailed: 'Ponawianie próby nie powiodło się', + times: 'razy', + retries: '{{liczba}} Ponownych prób', + retryOnFailure: 'Ponawianie próby w przypadku niepowodzenia', + retryFailedTimes: '{{times}} ponawianie prób nie powiodło się', + ms: 'Ms', + }, }, start: { required: 'wymagane', diff --git a/web/i18n/pt-BR/common.ts b/web/i18n/pt-BR/common.ts index a2c74a7bee..6f66e65878 100644 --- a/web/i18n/pt-BR/common.ts +++ b/web/i18n/pt-BR/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Pode editar aplicativos, mas não pode gerenciar configurações da equipe', inviteTeamMember: 'Adicionar membro da equipe', inviteTeamMemberTip: 'Eles podem acessar os dados da sua equipe diretamente após fazer login.', + emailNotSetup: 'O servidor de e-mail não está configurado, então os e-mails de convite não podem ser enviados. 
Por favor, notifique os usuários sobre o link de convite que será emitido após o convite.', email: 'E-mail', emailInvalid: 'Formato de e-mail inválido', emailPlaceholder: 'Por favor, insira e-mails', diff --git a/web/i18n/pt-BR/tools.ts b/web/i18n/pt-BR/tools.ts index 1b20715328..8af475a98a 100644 --- a/web/i18n/pt-BR/tools.ts +++ b/web/i18n/pt-BR/tools.ts @@ -144,6 +144,8 @@ const translation = { emptyTitle: 'Nenhuma ferramenta de fluxo de trabalho disponível', added: 'Adicionado', manageInTools: 'Gerenciar em Ferramentas', + emptyTitleCustom: 'Nenhuma ferramenta personalizada disponível', + emptyTipCustom: 'Criar uma ferramenta personalizada', }, openInStudio: 'Abrir no Studio', customToolTip: 'Saiba mais sobre as ferramentas personalizadas da Dify', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index 4d53ec07c7..b99c64cdf4 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Tratamento de erros', tip: 'Estratégia de tratamento de exceções, disparada quando um nó encontra uma exceção.', }, + retry: { + retry: 'Repetir', + retryOnFailure: 'Tentar novamente em caso de falha', + maxRetries: 'Máximo de tentativas', + retryInterval: 'Intervalo de repetição', + retryTimes: 'Tente novamente {{times}} vezes em caso de falha', + retrying: 'Repetindo...', + retrySuccessful: 'Repetição bem-sucedida', + retryFailed: 'Falha na nova tentativa', + retryFailedTimes: '{{times}} tentativas falharam', + times: 'vezes', + ms: 'ms', + retries: '{{num}} Tentativas', + }, }, start: { required: 'requerido', diff --git a/web/i18n/ro-RO/common.ts b/web/i18n/ro-RO/common.ts index 27a0ab6bf3..0badaf5a13 100644 --- a/web/i18n/ro-RO/common.ts +++ b/web/i18n/ro-RO/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Poate construi aplicații, dar nu poate gestiona setările echipei', inviteTeamMember: 'Adaugă membru în echipă', inviteTeamMemberTip: 'Pot accesa direct datele echipei dvs. după autentificare.', + emailNotSetup: 'Serverul de e-mail nu este configurat, astfel încât e-mailurile de invitație nu pot fi trimise. 
Vă rugăm să notificați utilizatorii despre linkul de invitație care va fi emis după invitație.', email: 'Email', emailInvalid: 'Format de email invalid', emailPlaceholder: 'Vă rugăm să introduceți emailuri',
diff --git a/web/i18n/ro-RO/tools.ts b/web/i18n/ro-RO/tools.ts index 165bdb26ed..baeffb2b66 100644 --- a/web/i18n/ro-RO/tools.ts +++ b/web/i18n/ro-RO/tools.ts @@ -144,6 +144,8 @@ const translation = { type: 'tip', emptyTitle: 'Nu este disponibil niciun instrument de flux de lucru', emptyTip: 'Accesați "Flux de lucru -> Publicați ca instrument"', + emptyTitleCustom: 'Nu este disponibil niciun instrument personalizat', + emptyTipCustom: 'Crearea unui instrument personalizat', }, openInStudio: 'Deschide în Studio', customToolTip: 'Aflați mai multe despre instrumentele personalizate Dify',
diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index 3dfa6d04ed..b142640c9b 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestionarea erorilor', tip: 'Strategie de gestionare a excepțiilor, declanșată atunci când un nod întâlnește o excepție.', }, + retry: { + retry: 'Reîncercare', + retryOnFailure: 'Reîncercați în caz de eșec', + maxRetries: 'numărul maxim de încercări', + retryInterval: 'Interval de reîncercare', + retrying: 'Se reîncearcă...', + retrySuccessful: 'Reîncercare reușită', + retryFailed: 'Reîncercarea a eșuat', + retryFailedTimes: '{{times}} reîncercări eșuate', + times: 'ori', + ms: 'ms', + retries: '{{num}} Încercări', + retryTimes: 'Reîncercați de {{times}} ori în caz de eșec', + }, }, start: { required: 'necesar',
diff --git a/web/i18n/ru-RU/common.ts b/web/i18n/ru-RU/common.ts index 6d9edf97c1..64a7c9375d 100644 --- a/web/i18n/ru-RU/common.ts +++ b/web/i18n/ru-RU/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Может управлять только базой знаний', inviteTeamMember: 'Добавить участника команды', inviteTeamMemberTip: 'Они могут получить доступ к данным вашей команды сразу после входа в систему.', + emailNotSetup: 'Почтовый сервер не настроен, поэтому приглашения по электронной почте не могут быть отправлены. Пожалуйста, уведомите пользователей о ссылке для приглашения, которая будет выдана после приглашения.', email: 'Электронная почта', emailInvalid: 'Неверный формат электронной почты', emailPlaceholder: 'Пожалуйста, введите адреса электронной почты',
diff --git a/web/i18n/ru-RU/tools.ts b/web/i18n/ru-RU/tools.ts index e0dfd571b2..4749fee163 100644 --- a/web/i18n/ru-RU/tools.ts +++ b/web/i18n/ru-RU/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Управлять в инструментах', emptyTitle: 'Нет доступных инструментов рабочего процесса', emptyTip: 'Перейдите в "Рабочий процесс -> Опубликовать как инструмент"', + emptyTitleCustom: 'Нет пользовательского инструмента', + emptyTipCustom: 'Создание пользовательского инструмента', }, createTool: { title: 'Создать пользовательский инструмент',
diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 600c59f2ed..49c43b4d6d 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обработка ошибок', tip: 'Стратегия обработки исключений, запускаемая при обнаружении исключения на узле.', }, + retry: { + retry: 'Повторить', + retryOnFailure: 'Повторная попытка при неудаче', + maxRetries: 'максимальное количество повторных попыток', + retryInterval: 'Интервал повторных попыток', + retryTimes: 'Повторите {{times}} раз при неудаче', + retrying: 'Повтор...', + retrySuccessful: 'Повторная попытка успешна', + retryFailed: 'Повторная попытка не удалась', + times: 'раз', + ms: 'мс', + retryFailedTimes: '{{times}} повторных попыток не увенчались успехом', + retries: '{{num}} повторных попыток', + }, }, start: { required: 'обязательно',
diff --git a/web/i18n/sl-SI/common.ts b/web/i18n/sl-SI/common.ts index dc399bd3a4..0c5d1dfc4b 100644 --- a/web/i18n/sl-SI/common.ts +++ b/web/i18n/sl-SI/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Lahko upravlja samo bazo znanja', inviteTeamMember: 'Dodaj člana ekipe', inviteTeamMemberTip: 'Do vaših podatkov bo lahko dostopal takoj po prijavi.', + emailNotSetup: 'E-poštni strežnik ni nastavljen, zato vabil po e-pošti ni mogoče poslati. Prosimo, obvestite uporabnike o povezavi za povabilo, ki bo izdana po povabilu.', email: 'E-pošta', emailInvalid: 'Neveljaven format e-pošte', emailPlaceholder: 'Vnesite e-poštne naslove',
diff --git a/web/i18n/sl-SI/tools.ts b/web/i18n/sl-SI/tools.ts index 57160cfe62..63b508a05d 100644 --- a/web/i18n/sl-SI/tools.ts +++ b/web/i18n/sl-SI/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Upravljaj v Orodjih', emptyTitle: 'Orodje za potek dela ni na voljo', emptyTip: 'Pojdite na "Potek dela -> Objavi kot orodje"', + emptyTipCustom: 'Ustvarjanje orodja po meri', + emptyTitleCustom: 'Orodje po meri ni na voljo', }, createTool: { title: 'Ustvari prilagojeno orodje',
diff --git a/web/i18n/sl-SI/workflow.ts b/web/i18n/sl-SI/workflow.ts index 2c9dab8b55..7c40c25e92 100644 --- a/web/i18n/sl-SI/workflow.ts +++ b/web/i18n/sl-SI/workflow.ts @@ -759,6 +759,20 @@ const translation = { title: 'Ravnanje z napakami', tip: 'Strategija ravnanja z izjemami, ki se sproži, ko vozlišče naleti na izjemo.', }, + retry: { + retryOnFailure: 'Ponovni poskus ob neuspehu', + retryInterval: 'Interval ponovnega poskusa', + retrying: 'Ponovno poskušanje...', + retry: 'Ponovi', + retryFailedTimes: '{{times}} ponovni poskusi niso uspeli', + retries: '{{num}} Poskusov', + times: 'krat', + retryTimes: 'Ponovni poskus {{times}}-krat ob neuspehu', + retryFailed: 'Ponovni poskus ni uspel', + retrySuccessful: 'Ponovni poskus je bil uspešen', + maxRetries: 'Največ ponovnih poskusov', + ms: 'ms', + }, }, start: { outputVars: {
diff --git a/web/i18n/th-TH/common.ts b/web/i18n/th-TH/common.ts index 82eddae723..6aa1b30610 100644 --- a/web/i18n/th-TH/common.ts +++ b/web/i18n/th-TH/common.ts @@ -194,6 +194,7 @@ const translation = { datasetOperatorTip: 'สามารถจัดการฐานความรู้ได้เท่านั้น', inviteTeamMember: 'เพิ่มสมาชิกในทีม', inviteTeamMemberTip: 'พวกเขาสามารถเข้าถึงข้อมูลทีมของคุณได้โดยตรงหลังจากลงชื่อเข้าใช้', + emailNotSetup: 'เซิร์ฟเวอร์อีเมลไม่ได้ตั้งค่าไว้ จึงไม่สามารถส่งอีเมลเชิญได้ กรุณาแจ้งผู้ใช้เกี่ยวกับลิงก์เชิญที่จะออกหลังจากการเชิญแทน', email: 'อีเมล', emailInvalid: 'รูปแบบอีเมลไม่ถูกต้อง', emailPlaceholder: 'กรุณากรอกอีเมล',
diff --git a/web/i18n/th-TH/tools.ts b/web/i18n/th-TH/tools.ts index a3e12bafd0..98272e83f5 100644 --- a/web/i18n/th-TH/tools.ts +++ b/web/i18n/th-TH/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'จัดการในเครื่องมือ', emptyTitle: 'ไม่มีเครื่องมือเวิร์กโฟลว์', emptyTip: 'ไปที่ "เวิร์กโฟลว์ -> เผยแพร่เป็นเครื่องมือ"', + emptyTitleCustom: 'ไม่มีเครื่องมือที่กําหนดเอง', + emptyTipCustom: 'สร้างเครื่องมือแบบกําหนดเอง', }, createTool: { title: 'สร้างเครื่องมือที่กําหนดเอง',
diff --git a/web/i18n/th-TH/workflow.ts b/web/i18n/th-TH/workflow.ts index c4305466aa..b8d2e72de0 100644 --- a/web/i18n/th-TH/workflow.ts +++ b/web/i18n/th-TH/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'การจัดการข้อผิดพลาด', tip: 'กลยุทธ์การจัดการข้อยกเว้น ทริกเกอร์เมื่อโหนดพบข้อยกเว้น', }, + retry: { + retry: 'ลองใหม่', + retryOnFailure: 'ลองใหม่เมื่อล้มเหลว', + maxRetries: 'การลองซ้ําสูงสุด', + retryInterval: 'ช่วงเวลาลองใหม่', + retryTimes: 'ลอง {{times}} ครั้งเมื่อล้มเหลว', + retrying: 'กําลังลองซ้ํา...', + retrySuccessful: 'ลองใหม่สําเร็จ', + retryFailed: 'ลองใหม่ล้มเหลว', + retryFailedTimes: '{{times}} การลองซ้ําล้มเหลว', + times: 'ครั้ง', + retries: '{{num}} การลองซ้ํา', + ms: 'มิลลิวินาที', + }, }, start: { required: 'ต้องระบุ',
diff --git a/web/i18n/tr-TR/app.ts b/web/i18n/tr-TR/app.ts index 1681fc9169..29c2aeaf45 100644 --- a/web/i18n/tr-TR/app.ts +++ b/web/i18n/tr-TR/app.ts @@ -112,7 +112,7 @@ const translation = { removeOriginal: 'Orijinal uygulamayı sil', switchStart: 'Geçişi Başlat', typeSelector: { - all: 'ALL Types', + all: 'All Types', chatbot: 'Chatbot', agent: 'Agent', workflow: 'Workflow',
diff --git a/web/i18n/tr-TR/common.ts b/web/i18n/tr-TR/common.ts index 320517925a..9792f07e18 100644 --- a/web/i18n/tr-TR/common.ts +++ b/web/i18n/tr-TR/common.ts @@ -199,6 +199,7 @@ const translation = { datasetOperatorTip: 'Sadece bilgi tabanını yönetebilir', inviteTeamMember: 'Takım Üyesi Ekle', inviteTeamMemberTip: 'Giriş yaptıktan sonra takım verilerinize doğrudan erişebilirler.', + emailNotSetup: 'E-posta sunucusu kurulu değil, bu nedenle davet e-postaları gönderilemiyor. Lütfen kullanıcıları davetten sonra verilecek davet bağlantısı hakkında bilgilendirin.', email: 'E-posta', emailInvalid: 'Geçersiz E-posta Formatı', emailPlaceholder: 'Lütfen e-postaları girin',
diff --git a/web/i18n/tr-TR/tools.ts b/web/i18n/tr-TR/tools.ts index 00af8ed7f2..a579ac82f1 100644 --- a/web/i18n/tr-TR/tools.ts +++ b/web/i18n/tr-TR/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: 'Araçlarda Yönet', emptyTitle: 'Kullanılabilir workflow aracı yok', emptyTip: 'Git "Workflow -> Araç olarak Yayınla"', + emptyTitleCustom: 'Özel bir araç yok', + emptyTipCustom: 'Özel bir araç oluşturun', }, createTool: { title: 'Özel Araç Oluştur',
diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 951a20e049..edec6a0b49 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Hata İşleme', tip: 'Bir düğüm bir özel durumla karşılaştığında tetiklenen özel durum işleme stratejisi.', }, + retry: { + retry: 'Yeniden dene', + retryOnFailure: 'Hata durumunda yeniden dene', + maxRetries: 'En fazla yeniden deneme', + times: 'kere', + retries: '{{num}} yeniden deneme', + retryFailed: 'Yeniden deneme başarısız oldu', + retryInterval: 'Yeniden deneme aralığı', + retryTimes: 'Hata durumunda {{times}} kez yeniden deneyin', + retryFailedTimes: '{{times}} yeniden deneme başarısız oldu', + retrySuccessful: 'Yeniden deneme başarılı', + retrying: 'Yeniden deneniyor...', + ms: 'ms', + }, }, start: { required: 'gerekli',
diff --git a/web/i18n/uk-UA/common.ts b/web/i18n/uk-UA/common.ts index 6bf6dafc4f..fbe9b67750 100644 --- a/web/i18n/uk-UA/common.ts +++ b/web/i18n/uk-UA/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Може створювати програми, але не може керувати налаштуваннями команди', inviteTeamMember: 'Додати учасника команди', inviteTeamMemberTip: 'Вони зможуть отримати доступ до даних вашої команди безпосередньо після входу.', + emailNotSetup: 'Поштовий сервер не налаштований, тому запрошення електронною поштою не можуть бути надіслані. Будь ласка, повідомте користувачів про посилання для запрошення, яке буде видано після запрошення.', email: 'Електронна пошта', emailInvalid: 'Недійсний формат електронної пошти', emailPlaceholder: 'Будь ласка, введіть адресу електронної пошти',
diff --git a/web/i18n/uk-UA/tools.ts b/web/i18n/uk-UA/tools.ts index 309a450afc..f84d0d82cc 100644 --- a/web/i18n/uk-UA/tools.ts +++ b/web/i18n/uk-UA/tools.ts @@ -144,6 +144,8 @@ const translation = { manageInTools: 'Керування в інструментах', emptyTip: 'Перейдіть до розділу "Робочий процес -> Опублікувати як інструмент"', emptyTitle: 'Немає доступного інструменту для роботи з робочими процесами', + emptyTitleCustom: 'Немає доступного спеціального інструменту', + emptyTipCustom: 'Створення власного інструмента', }, openInStudio: 'Відкрити в Студії', customToolTip: 'Дізнайтеся більше про користувацькі інструменти Dify',
diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 2c00d3bf59..29fd9d8188 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обробка помилок', tip: 'Стратегія обробки винятків, що спрацьовує, коли вузол стикається з винятком.', }, + retry: { + retry: 'Повторити', + retryOnFailure: 'повторити спробу в разі невдачі', + retryInterval: 'Інтервал повторних спроб', + retrying: 'Повторна спроба...', + retryFailed: 'Повторна спроба не вдалася', + times: 'разів', + ms: 'мс', + retries: '{{num}} Спроб', + maxRetries: 'Максимальна кількість повторних спроб', + retrySuccessful: 'Повторна спроба успішна', + retryFailedTimes: '{{times}} повторні спроби не вдалися', + retryTimes: 'Повторіть спробу {{times}} разів у разі невдачі', + }, }, start: { required: 'обов\'язковий',
diff --git a/web/i18n/vi-VN/common.ts b/web/i18n/vi-VN/common.ts index bf5339f40e..8bafd86854 100644 --- a/web/i18n/vi-VN/common.ts +++ b/web/i18n/vi-VN/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: 'Có thể xây dựng ứng dụng, không thể quản lý cài đặt nhóm', inviteTeamMember: 'Mời thành viên nhóm', inviteTeamMemberTip: 'Sau khi đăng nhập, họ có thể truy cập trực tiếp vào dữ liệu nhóm của bạn.', + emailNotSetup: 'Máy chủ email chưa được thiết lập, vì vậy không thể gửi email mời. Vui lòng thông báo cho người dùng về liên kết mời sẽ được phát hành sau khi mời.', email: 'Email', emailInvalid: 'Định dạng Email không hợp lệ', emailPlaceholder: 'Vui lòng nhập email',
diff --git a/web/i18n/vi-VN/tools.ts b/web/i18n/vi-VN/tools.ts index b03a6ccc98..86c55166f9 100644 --- a/web/i18n/vi-VN/tools.ts +++ b/web/i18n/vi-VN/tools.ts @@ -144,6 +144,8 @@ const translation = { added: 'Thêm', emptyTip: 'Đi tới "Quy trình làm việc -> Xuất bản dưới dạng công cụ"', emptyTitle: 'Không có sẵn công cụ quy trình làm việc', + emptyTitleCustom: 'Không có công cụ tùy chỉnh nào có sẵn', + emptyTipCustom: 'Tạo công cụ tùy chỉnh', }, toolNameUsageTip: 'Tên cuộc gọi công cụ để lý luận và nhắc nhở tổng đài viên', customToolTip: 'Tìm hiểu thêm về các công cụ tùy chỉnh Dify',
diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 956fe84159..9e16cb5347 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Chiến lược xử lý ngoại lệ, được kích hoạt khi một nút gặp phải ngoại lệ.', title: 'Xử lý lỗi', }, + retry: { + retry: 'Thử lại', + maxRetries: 'Số lần thử lại tối đa', + retryInterval: 'Khoảng thời gian thử lại', + retryTimes: 'Thử lại {{times}} lần khi không thành công', + retrying: 'Thử lại...', + retrySuccessful: 'Thử lại thành công', + retryFailed: 'Thử lại không thành công', + retryFailedTimes: '{{times}} lần thử lại không thành công', + retries: '{{num}} lần thử lại', + retryOnFailure: 'Thử lại khi không thành công', + times: 'lần', + ms: 'ms', + }, }, start: { required: 'bắt buộc',
diff --git a/web/i18n/zh-Hans/common.ts b/web/i18n/zh-Hans/common.ts index 70caaa976a..96e08a9337 100644 --- a/web/i18n/zh-Hans/common.ts +++ b/web/i18n/zh-Hans/common.ts @@ -197,6 +197,7 @@ const translation = { datasetOperatorTip: '只能管理知识库', inviteTeamMember: '添加团队成员', inviteTeamMemberTip: '对方在登录后可以访问你的团队数据。', + emailNotSetup: '由于邮件服务器未设置,无法发送邀请邮件。请将邀请后生成的邀请链接通知用户。', email: '邮箱', emailInvalid: '邮箱格式无效', emailPlaceholder: '输入邮箱',
diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts index 1473fc23d3..a788ef0abe 100644 --- a/web/i18n/zh-Hans/tools.ts +++ b/web/i18n/zh-Hans/tools.ts @@ -31,6 +31,8 @@ const translation = { manageInTools: '去工具列表管理', emptyTitle: '没有可用的工作流工具', emptyTip: '去 “工作流 -> 发布为工具” 添加', + emptyTitleCustom: '没有可用的自定义工具', + emptyTipCustom: '创建自定义工具', }, createTool: { title: '创建自定义工具',
diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 19cda33057..dfad9208e7 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: '流程中有 {{num}} 个节点运行异常,请前往追踪查看日志。', }, }, + retry: { + retry: '重试', + retryOnFailure: '失败时重试', + maxRetries: '最大重试次数', + retryInterval: '重试间隔', + retryTimes: '失败时重试 {{times}} 次', + retrying: '重试中...', + retrySuccessful: '重试成功', + retryFailed: '重试失败', + retryFailedTimes: '{{times}} 次重试失败', + times: '次', + ms: '毫秒', + retries: '{{num}} 次重试', + }, }, start: { required: '必填',
diff --git a/web/i18n/zh-Hant/common.ts b/web/i18n/zh-Hant/common.ts index 09c1e9d839..8340650993 100644 --- a/web/i18n/zh-Hant/common.ts +++ b/web/i18n/zh-Hant/common.ts @@ -191,6 +191,7 @@ const translation = { editorTip: '能夠建立並編輯應用程式,不能管理團隊設定', inviteTeamMember: '新增團隊成員', inviteTeamMemberTip: '對方在登入後可以訪問你的團隊資料。', + emailNotSetup: '由於郵件伺服器未設置,無法發送邀請郵件。請將邀請後生成的邀請連結通知用戶。', email: '郵箱', emailInvalid: '郵箱格式無效', emailPlaceholder: '輸入郵箱',
diff --git a/web/i18n/zh-Hant/tools.ts b/web/i18n/zh-Hant/tools.ts index d45980c017..40a63eff65 100644 --- a/web/i18n/zh-Hant/tools.ts +++ b/web/i18n/zh-Hant/tools.ts @@ -144,6 +144,8 @@ const translation = { category: '類別', emptyTitle: '沒有可用的工作流程工具', emptyTip: '轉到“工作流 - >發佈為工具”', + emptyTipCustom: '創建自訂工具', + emptyTitleCustom: '沒有可用的自訂工具', }, customToolTip: '瞭解有關 Dify 自訂工具的更多資訊', toolNameUsageTip: '用於代理推理和提示的工具調用名稱',
diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 4bbbf7a04f..a78c6a2f04 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '錯誤處理', tip: '異常處理策略,當節點遇到異常時觸發。', }, + retry: { + retry: '重試', + retryOnFailure: '失敗時重試', + maxRetries: '最大重試次數', + retryInterval: '重試間隔', + retryTimes: '失敗時重試 {{times}} 次', + retrying: '重試中...', + retrySuccessful: '重試成功', + retryFailed: '重試失敗', + retryFailedTimes: '{{times}} 次重試失敗', + times: '次', + ms: '毫秒', + retries: '{{num}} 次重試', + }, }, start: { required: '必填',
diff --git a/web/package.json b/web/package.json index 44e92806d3..d9515645c8 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "dify-web", - "version": "0.13.2", + "version": "0.14.2", "private": true, "engines": { "node": ">=18.17.0" @@ -50,7 +50,7 @@ "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", "dayjs": "^1.11.7", - "echarts": "^5.4.1", + "echarts": "^5.5.1", "echarts-for-react": "^3.0.2", "elkjs": "^0.9.3", "emoji-mart": "^5.5.2",
diff --git a/web/public/screenshots/Light/Agent.png b/web/public/screenshots/Light/Agent.png deleted file mode 100644 index fe596a555f..0000000000 Binary files a/web/public/screenshots/Light/Agent.png and /dev/null differ diff --git a/web/public/screenshots/Light/Agent@2x.png b/web/public/screenshots/Light/Agent@2x.png deleted file mode 100644 index dda71b29e9..0000000000 Binary files a/web/public/screenshots/Light/Agent@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Agent@3x.png b/web/public/screenshots/Light/Agent@3x.png deleted file mode 100644 index 0d05644eab..0000000000 Binary files a/web/public/screenshots/Light/Agent@3x.png and /dev/null differ diff --git a/web/public/screenshots/Light/ChatFlow.png b/web/public/screenshots/Light/ChatFlow.png deleted file mode 100644 index 1753de7763..0000000000 Binary files a/web/public/screenshots/Light/ChatFlow.png and /dev/null differ diff --git a/web/public/screenshots/Light/ChatFlow@2x.png b/web/public/screenshots/Light/ChatFlow@2x.png deleted file mode 100644 index 6b72a8d732..0000000000 Binary files a/web/public/screenshots/Light/ChatFlow@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/ChatFlow@3x.png b/web/public/screenshots/Light/ChatFlow@3x.png deleted file mode 100644 index 7a059af6a4..0000000000 Binary files a/web/public/screenshots/Light/ChatFlow@3x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatbot.png b/web/public/screenshots/Light/Chatbot.png deleted file mode 100644 index b628a930fb..0000000000 Binary files a/web/public/screenshots/Light/Chatbot.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatbot@2x.png b/web/public/screenshots/Light/Chatbot@2x.png deleted file mode 100644 index 048a9f9036..0000000000 Binary files a/web/public/screenshots/Light/Chatbot@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatbot@3x.png b/web/public/screenshots/Light/Chatbot@3x.png deleted file mode 100644 index 9b7c1f5999..0000000000 Binary files a/web/public/screenshots/Light/Chatbot@3x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatflow.png
b/web/public/screenshots/Light/Chatflow.png deleted file mode 100644 index 1753de7763..0000000000 Binary files a/web/public/screenshots/Light/Chatflow.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatflow@2x.png b/web/public/screenshots/Light/Chatflow@2x.png deleted file mode 100644 index 6b72a8d732..0000000000 Binary files a/web/public/screenshots/Light/Chatflow@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Chatflow@3x.png b/web/public/screenshots/Light/Chatflow@3x.png deleted file mode 100644 index 7a059af6a4..0000000000 Binary files a/web/public/screenshots/Light/Chatflow@3x.png and /dev/null differ diff --git a/web/public/screenshots/Light/TextGenerator.png b/web/public/screenshots/Light/TextGenerator.png deleted file mode 100644 index 14973451cc..0000000000 Binary files a/web/public/screenshots/Light/TextGenerator.png and /dev/null differ diff --git a/web/public/screenshots/Light/TextGenerator@2x.png b/web/public/screenshots/Light/TextGenerator@2x.png deleted file mode 100644 index 7e1baae97b..0000000000 Binary files a/web/public/screenshots/Light/TextGenerator@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/TextGenerator@3x.png b/web/public/screenshots/Light/TextGenerator@3x.png deleted file mode 100644 index 746e9ac1be..0000000000 Binary files a/web/public/screenshots/Light/TextGenerator@3x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Workflow.png b/web/public/screenshots/Light/Workflow.png deleted file mode 100644 index a82c9a6a4d..0000000000 Binary files a/web/public/screenshots/Light/Workflow.png and /dev/null differ diff --git a/web/public/screenshots/Light/Workflow@2x.png b/web/public/screenshots/Light/Workflow@2x.png deleted file mode 100644 index 0a1a19435b..0000000000 Binary files a/web/public/screenshots/Light/Workflow@2x.png and /dev/null differ diff --git a/web/public/screenshots/Light/Workflow@3x.png b/web/public/screenshots/Light/Workflow@3x.png deleted file mode 100644 index 914ce45003..0000000000 Binary files a/web/public/screenshots/Light/Workflow@3x.png and /dev/null differ diff --git a/web/public/screenshots/light/Agent.png b/web/public/screenshots/light/Agent.png deleted file mode 100644 index fe596a555f..0000000000 Binary files a/web/public/screenshots/light/Agent.png and /dev/null differ diff --git a/web/public/screenshots/light/Agent@2x.png b/web/public/screenshots/light/Agent@2x.png deleted file mode 100644 index dda71b29e9..0000000000 Binary files a/web/public/screenshots/light/Agent@2x.png and /dev/null differ diff --git a/web/public/screenshots/light/Agent@3x.png b/web/public/screenshots/light/Agent@3x.png deleted file mode 100644 index 0d05644eab..0000000000 Binary files a/web/public/screenshots/light/Agent@3x.png and /dev/null differ diff --git a/web/public/screenshots/light/Chatbot.png b/web/public/screenshots/light/Chatbot.png deleted file mode 100644 index b628a930fb..0000000000 Binary files a/web/public/screenshots/light/Chatbot.png and /dev/null differ diff --git a/web/public/screenshots/light/Chatbot@2x.png b/web/public/screenshots/light/Chatbot@2x.png deleted file mode 100644 index 048a9f9036..0000000000 Binary files a/web/public/screenshots/light/Chatbot@2x.png and /dev/null differ diff --git a/web/public/screenshots/light/Chatbot@3x.png b/web/public/screenshots/light/Chatbot@3x.png deleted file mode 100644 index 9b7c1f5999..0000000000 Binary files a/web/public/screenshots/light/Chatbot@3x.png and /dev/null differ diff --git 
a/web/public/screenshots/light/Chatflow.png b/web/public/screenshots/light/Chatflow.png deleted file mode 100644 index 1753de7763..0000000000 Binary files a/web/public/screenshots/light/Chatflow.png and /dev/null differ diff --git a/web/public/screenshots/light/Chatflow@2x.png b/web/public/screenshots/light/Chatflow@2x.png deleted file mode 100644 index 6b72a8d732..0000000000 Binary files a/web/public/screenshots/light/Chatflow@2x.png and /dev/null differ diff --git a/web/public/screenshots/light/Chatflow@3x.png b/web/public/screenshots/light/Chatflow@3x.png deleted file mode 100644 index 7a059af6a4..0000000000 Binary files a/web/public/screenshots/light/Chatflow@3x.png and /dev/null differ diff --git a/web/public/screenshots/light/TextGenerator.png b/web/public/screenshots/light/TextGenerator.png deleted file mode 100644 index 14973451cc..0000000000 Binary files a/web/public/screenshots/light/TextGenerator.png and /dev/null differ diff --git a/web/public/screenshots/light/TextGenerator@2x.png b/web/public/screenshots/light/TextGenerator@2x.png deleted file mode 100644 index 7e1baae97b..0000000000 Binary files a/web/public/screenshots/light/TextGenerator@2x.png and /dev/null differ diff --git a/web/public/screenshots/light/TextGenerator@3x.png b/web/public/screenshots/light/TextGenerator@3x.png deleted file mode 100644 index 746e9ac1be..0000000000 Binary files a/web/public/screenshots/light/TextGenerator@3x.png and /dev/null differ diff --git a/web/public/screenshots/light/Workflow.png b/web/public/screenshots/light/Workflow.png deleted file mode 100644 index a82c9a6a4d..0000000000 Binary files a/web/public/screenshots/light/Workflow.png and /dev/null differ diff --git a/web/public/screenshots/light/Workflow@2x.png b/web/public/screenshots/light/Workflow@2x.png deleted file mode 100644 index 0a1a19435b..0000000000 Binary files a/web/public/screenshots/light/Workflow@2x.png and /dev/null differ diff --git a/web/public/screenshots/light/Workflow@3x.png b/web/public/screenshots/light/Workflow@3x.png deleted file mode 100644 index 914ce45003..0000000000 Binary files a/web/public/screenshots/light/Workflow@3x.png and /dev/null differ diff --git a/web/service/base.ts b/web/service/base.ts index 03421d92a4..22b1a43ad1 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -62,6 +62,7 @@ export type IOnNodeStarted = (nodeStarted: NodeStartedResponse) => void export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void +export type IOnNodeRetry = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void @@ -92,6 +93,7 @@ export type IOtherOptions = { onIterationStart?: IOnIterationStarted onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onNodeRetry?: IOnNodeRetry onParallelBranchStarted?: IOnParallelBranchStarted onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk @@ -165,6 +167,7 @@ const handleStream = ( onIterationStart?: IOnIterationStarted, onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onNodeRetry?: IOnNodeRetry, onParallelBranchStarted?: 
IOnParallelBranchStarted, onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, @@ -256,6 +259,9 @@ const handleStream = ( else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'node_retry') { + onNodeRetry?.(bufferObj as NodeFinishedResponse) + } else if (bufferObj.event === 'parallel_branch_started') { onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) } @@ -462,6 +468,7 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, @@ -533,7 +540,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' && !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/service/use-workflow.ts b/web/service/use-workflow.ts new file mode 100644 index 0000000000..948a114b04 --- /dev/null +++ b/web/service/use-workflow.ts @@ -0,0 +1,12 @@ +import { useQuery } from '@tanstack/react-query' +import { get } from './base' +import type { WorkflowConfigResponse } from '@/types/workflow' + +const NAME_SPACE = 'workflow' + +export const useWorkflowConfig = (appId: string) => { + return useQuery({ + queryKey: [NAME_SPACE, 'config', appId], + queryFn: () => get(`/apps/${appId}/workflows/draft/config`), + }) +} diff --git a/web/tailwind.config.js b/web/tailwind.config.js index a109d193d7..b4e2263167 100644 --- a/web/tailwind.config.js +++ b/web/tailwind.config.js @@ -93,6 +93,7 @@ module.exports = { 'chatbot-bg': 'var(--color-chatbot-bg)', 'chat-bubble-bg': 'var(--color-chat-bubble-bg)', 'workflow-process-bg': 'var(--color-workflow-process-bg)', + 'mask-top2bottom-gray-50-to-transparent': 'var(--mask-top2bottom-gray-50-to-transparent)', }, }, }, diff --git a/web/themes/manual-dark.css b/web/themes/manual-dark.css index 047554f4b4..4052e5566d 100644 --- a/web/themes/manual-dark.css +++ b/web/themes/manual-dark.css @@ -1,5 +1,27 @@ html[data-theme="dark"] { - --color-chatbot-bg: linear-gradient(180deg, rgba(34, 34, 37, 0.90) 0%, rgba(29, 29, 32, 0.90) 90.48%); - --color-chat-bubble-bg: linear-gradient(180deg, rgba(200, 206, 218, 0.08) 0%, rgba(200, 206, 218, 0.02) 100%); - --color-workflow-process-bg: linear-gradient(90deg, rgba(24, 24, 27, 0.25) 0%, rgba(24, 24, 27, 0.04) 100%); -} \ No newline at end of file + --color-chatbot-bg: linear-gradient( + 180deg, + rgba(34, 34, 37, 0.9) 0%, + rgba(29, 29, 32, 0.9) 90.48% + ); + --color-chat-bubble-bg: linear-gradient( + 180deg, + rgba(200, 206, 218, 0.08) 0%, + rgba(200, 206, 218, 0.02) 100% + ); + --color-workflow-process-bg: linear-gradient( + 90deg, + rgba(24, 24, 27, 0.25) 0%, + rgba(24, 24, 27, 0.04) 100% + ); + --color-account-teams-bg: 
linear-gradient( + 271deg, + rgba(34, 34, 37, 0.9) -0.1%, + rgba(29, 29, 32, 0.9) 98.26% + ); + --mask-top2bottom-gray-50-to-transparent: linear-gradient( + 180deg, + rgba(24, 24, 27, 0.08) 0%, + rgba(0, 0, 0, 0) 100% + ); +} diff --git a/web/themes/manual-light.css b/web/themes/manual-light.css index 09b9338184..303963c55d 100644 --- a/web/themes/manual-light.css +++ b/web/themes/manual-light.css @@ -1,5 +1,27 @@ html[data-theme="light"] { - --color-chatbot-bg: linear-gradient(180deg, rgba(249, 250, 251, 0.90) 0%, rgba(242, 244, 247, 0.90) 90.48%); - --color-chat-bubble-bg: linear-gradient(180deg, #FFF 0%, rgba(255, 255, 255, 0.60) 100%); - --color-workflow-process-bg: linear-gradient(90deg, rgba(200, 206, 218, 0.20) 0%, rgba(200, 206, 218, 0.04) 100%); -} \ No newline at end of file + --color-chatbot-bg: linear-gradient( + 180deg, + rgba(249, 250, 251, 0.9) 0%, + rgba(242, 244, 247, 0.9) 90.48% + ); + --color-chat-bubble-bg: linear-gradient( + 180deg, + #fff 0%, + rgba(255, 255, 255, 0.6) 100% + ); + --color-workflow-process-bg: linear-gradient( + 90deg, + rgba(200, 206, 218, 0.2) 0%, + rgba(200, 206, 218, 0.04) 100% + ); + --color-account-teams-bg: linear-gradient( + 271deg, + rgba(249, 250, 251, 0.9) -0.1%, + rgba(242, 244, 247, 0.9) 98.26% + ); + --mask-top2bottom-gray-50-to-transparent: linear-gradient( + 180deg, + rgba(200, 206, 218, 0.2) 0%, + rgba(255, 255, 255, 0) 100% + ); +} diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts index 6329ce3d26..ea5f80b88b 100644 --- a/web/themes/tailwind-theme-var-define.ts +++ b/web/themes/tailwind-theme-var-define.ts @@ -399,6 +399,7 @@ const vars = { 'background-default-burn': 'var(--color-background-default-burn)', 'background-overlay-fullscreen': 'var(--color-background-overlay-fullscreen)', 'background-default-lighter': 'var(--color-background-default-lighter)', + 'background-account-teams-bg': 'var(--color-account-teams-bg)', 'background-section': 'var(--color-background-section)', 'background-interaction-from-bg-1': 'var(--color-background-interaction-from-bg-1)', 'background-interaction-from-bg-2': 'var(--color-background-interaction-from-bg-2)', diff --git a/web/types/feature.ts b/web/types/feature.ts index 47e8e1aad1..053ce3d7c9 100644 --- a/web/types/feature.ts +++ b/web/types/feature.ts @@ -29,6 +29,7 @@ export type SystemFeatures = { enable_social_oauth_login: boolean is_allow_create_workspace: boolean is_allow_register: boolean + is_email_setup: boolean license: License } @@ -43,6 +44,7 @@ export const defaultSystemFeatures: SystemFeatures = { enable_social_oauth_login: false, is_allow_create_workspace: false, is_allow_register: false, + is_email_setup: false, license: { status: LicenseStatus.NONE, expired_at: '', diff --git a/web/types/workflow.ts b/web/types/workflow.ts index a5db7e635d..cd6e9cfa5f 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -52,10 +52,12 @@ export type NodeTracing = { extras?: any expand?: boolean // for UI details?: NodeTracing[][] // iteration detail + retryDetail?: NodeTracing[] // retry detail parallel_id?: string parallel_start_node_id?: string parent_parallel_id?: string parent_parallel_start_node_id?: string + retry_index?: number } export type FetchWorkflowDraftResponse = { @@ -178,6 +180,7 @@ export type NodeFinishedResponse = { } created_at: number files?: FileResponse[] + retry_index?: number } } @@ -333,3 +336,7 @@ export type ConversationVariableResponse = { } export type IterationDurationMap = Record + +export type 
WorkflowConfigResponse = { + parallel_depth_limit: number +} diff --git a/web/yarn.lock b/web/yarn.lock index 6389f985c7..c9590d6b06 100644 --- a/web/yarn.lock +++ b/web/yarn.lock @@ -5903,13 +5903,13 @@ echarts-for-react@^3.0.2: fast-deep-equal "^3.1.3" size-sensor "^1.0.1" -echarts@^5.4.1: - version "5.4.2" - resolved "https://registry.npmjs.org/echarts/-/echarts-5.4.2.tgz" - integrity sha512-2W3vw3oI2tWJdyAz+b8DuWS0nfXtSDqlDmqgin/lfzbkB01cuMEN66KWBlmur3YMp5nEDEEt5s23pllnAzB4EA== +echarts@^5.5.1: + version "5.5.1" + resolved "https://registry.yarnpkg.com/echarts/-/echarts-5.5.1.tgz#8dc9c68d0c548934bedcb5f633db07ed1dd2101c" + integrity sha512-Fce8upazaAXUVUVsjgV6mBnGuqgO+JNDlcgF79Dksy4+wgGpQB2lmYoO4TSweFg/mZITdpGHomw/cNBJZj1icA== dependencies: tslib "2.3.0" - zrender "5.4.3" + zrender "5.6.0" electron-to-chromium@^1.5.41: version "1.5.52" @@ -9882,9 +9882,9 @@ nan@^2.17.0: integrity sha512-nbajikzWTMwsW+eSsNm3QwlOs7het9gGJU5dDZzRTQGk03vyBOauxgI4VakDzE0PtsGTmXPsXTbbjVhRwR5mpw== nanoid@^3.3.6, nanoid@^3.3.7: - version "3.3.7" - resolved "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz" - integrity sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g== + version "3.3.8" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.8.tgz#b1be3030bee36aaff18bacb375e5cce521684baf" + integrity sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w== natural-compare-lite@^1.4.0: version "1.4.0" @@ -13330,10 +13330,10 @@ zod@^3.23.6: resolved "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz" integrity sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g== -zrender@5.4.3: - version "5.4.3" - resolved "https://registry.npmjs.org/zrender/-/zrender-5.4.3.tgz" - integrity sha512-DRUM4ZLnoaT0PBVvGBDO9oWIDBKFdAVieNWxWwK0niYzJCMwGchRk21/hsE+RKkIveH3XHCyvXcJDkgLVvfizQ== +zrender@5.6.0: + version "5.6.0" + resolved "https://registry.yarnpkg.com/zrender/-/zrender-5.6.0.tgz#01325b0bb38332dd5e87a8dbee7336cafc0f4a5b" + integrity sha512-uzgraf4njmmHAbEUxMJ8Oxg+P3fT04O+9p7gY+wJRVxo8Ge+KmYv0WJev945EH4wFuc4OY2NLXz46FZrWS9xJg== dependencies: tslib "2.3.0"
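A note on the retry strings added across the locale files above: they are resolved through i18next-style interpolation, so the names inside `{{...}}` are lookup keys shared with the calling code and must stay as `times` and `num` in every locale; a placeholder whose name has itself been translated would render literally, braces and all. A minimal sketch of the mechanism (a plain i18next setup for illustration only, not Dify's actual i18n bootstrap or key paths):

```ts
import i18next from 'i18next'

// i18next substitutes {{times}} with the `times` option passed by the caller,
// so the placeholder name must match across locales even when the surrounding
// sentence is translated.
i18next.init({
  lng: 'es-ES',
  resources: {
    'es-ES': {
      translation: {
        retryFailedTimes: '{{times}} reintentos fallidos',
      },
    },
  },
}).then(() => {
  console.log(i18next.t('retryFailedTimes', { times: 3 }))
  // -> "3 reintentos fallidos"
})
```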
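The `node_retry` SSE event wired into `web/service/base.ts` reuses the `NodeFinishedResponse` shape, so a streaming consumer can surface retries alongside normal node completions. A sketch of a hypothetical caller (the endpoint path and request body are illustrative, and the three-argument `ssePost(url, fetchOptions, otherOptions)` shape is assumed from the destructuring shown in the diff):

```ts
import { ssePost } from '@/service/base'
import type { NodeFinishedResponse } from '@/types/workflow'

// Hypothetical streaming run: the options object carries the event callbacks,
// including the onNodeRetry handler added in this change.
ssePost('/workflows/run', {
  body: { inputs: {}, response_mode: 'streaming' },
}, {
  onNodeRetry: (res: NodeFinishedResponse) => {
    // retry_index is the new optional field on the event payload; a UI could
    // append each such event to a node's retryDetail list in NodeTracing.
    console.log('node retried, attempt', res.data.retry_index)
  },
})
```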
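Likewise, `web/service/use-workflow.ts` introduces a TanStack Query hook over the new draft-config endpoint, pairing with the `WorkflowConfigResponse` type above. A hypothetical consumer (component name is invented, and the shape of the resolved `data` is assumed to be the parsed JSON returned by `get`):

```tsx
import type { WorkflowConfigResponse } from '@/types/workflow'
import { useWorkflowConfig } from '@/service/use-workflow'

// react-query caches the config per appId via the ['workflow', 'config', appId]
// key, so multiple components can read parallel_depth_limit without refetching.
const ParallelDepthLimitHint = ({ appId }: { appId: string }) => {
  const { data } = useWorkflowConfig(appId)
  const config = data as WorkflowConfigResponse | undefined
  if (!config)
    return null
  return <div>{`Max parallel nesting: ${config.parallel_depth_limit}`}</div>
}

export default ParallelDepthLimitHint
```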