diff --git a/.gitignore b/.gitignore index 2f44cf7934..97b7333dde 100644 --- a/.gitignore +++ b/.gitignore @@ -174,5 +174,6 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json pyrightconfig.json +api/.vscode .idea/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ce7ad7db98..f810584f24 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install. -Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot. +Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot. ### 5. Visit dify in your browser diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 08fd34e117..303c2513f5 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -77,7 +77,7 @@ Dify 依赖以下工具和库: Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。 -查看 [安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq) 以获取常见问题列表和故障排除步骤。 +查看 [安装常见问题解答](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) 以获取常见问题列表和故障排除步骤。 ### 5. 
在浏览器中访问 Dify diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index e8f5456a3c..1ce8436a78 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -82,7 +82,7 @@ Dify はバックエンドとフロントエンドから構成されています まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。 次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。 -よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。 +よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) を確認してください。 ### 5. ブラウザで dify にアクセスする diff --git a/api/.env.example b/api/.env.example index 228218be0d..474798cef7 100644 --- a/api/.env.example +++ b/api/.env.example @@ -256,3 +256,7 @@ WORKFLOW_CALL_MAX_DEPTH=5 # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 + + +# Celery beat configuration +CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file diff --git a/api/app.py b/api/app.py index f5a6d40e1a..2c484ace85 100644 --- a/api/app.py +++ b/api/app.py @@ -1,7 +1,5 @@ import os -from configs import dify_config - if os.environ.get("DEBUG", "false").lower() != 'true': from gevent import monkey @@ -23,7 +21,9 @@ from flask import Flask, Response, request from flask_cors import CORS from werkzeug.exceptions import Unauthorized +import contexts from commands import register_commands +from configs import dify_config # DO NOT REMOVE BELOW from events import event_handlers @@ -181,7 +181,10 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) user_id = decoded.get('user_id') - return AccountService.load_logged_in_account(account_id=user_id, token=auth_token) + account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) + if account: + contexts.tenant_id.set(account.current_tenant_id) + return account @login_manager.unauthorized_handler diff --git a/api/configs/feature/__init__.py 
b/api/configs/feature/__init__.py index c000c3a0f2..369b25d788 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -23,6 +23,7 @@ class SecurityConfig(BaseSettings): default=24, ) + class AppExecutionConfig(BaseSettings): """ App Execution configs @@ -405,7 +406,6 @@ class DataSetConfig(BaseSettings): default=False, ) - class WorkspaceConfig(BaseSettings): """ Workspace configs @@ -435,6 +435,13 @@ class ImageFormatConfig(BaseSettings): ) +class CeleryBeatConfig(BaseSettings): + CELERY_BEAT_SCHEDULER_TIME: int = Field( + description='the time of the celery scheduler, default to 1 day', + default=1, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -462,5 +469,6 @@ class FeatureConfig( # hosted services config HostedServiceConfig, + CeleryBeatConfig, ): pass diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 209d46bb76..88fe188587 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings): default=False, ) - HOSTED_OPENAI_API_KEY: Optional[str] = Field( + HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( description='', default=None, ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index e513cad0e8..895cd6f176 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field, PositiveInt @@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel): MyScale configs """ - MYSCALE_HOST: Optional[str] = Field( + MYSCALE_HOST: str = Field( description='MyScale host', - default=None, + default='localhost', ) - MYSCALE_PORT: Optional[PositiveInt] = Field( + MYSCALE_PORT: PositiveInt = Field( description='MyScale port', default=8123, ) - MYSCALE_USER: 
Optional[str] = Field( + MYSCALE_USER: str = Field( description='MyScale user', - default=None, + default='default', ) - MYSCALE_PASSWORD: Optional[str] = Field( + MYSCALE_PASSWORD: str = Field( description='MyScale password', - default=None, + default='', ) - MYSCALE_DATABASE: Optional[str] = Field( + MYSCALE_DATABASE: str = Field( description='MyScale database name', - default=None, + default='default', ) - MYSCALE_FTS_PARAMS: Optional[str] = Field( + MYSCALE_FTS_PARAMS: str = Field( description='MyScale fts index parameters', - default=None, + default='', ) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e69de29bb2..08a2786994 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -0,0 +1,2 @@ +# TODO: Update all string in code to use this constant +HIDDEN_VALUE = '[__HIDDEN__]' \ No newline at end of file diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py new file mode 100644 index 0000000000..306fac3a93 --- /dev/null +++ b/api/contexts/__init__.py @@ -0,0 +1,3 @@ +from contextvars import ContextVar + +tenant_id: ContextVar[str] = ContextVar('tenant_id') \ No newline at end of file diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 1c42a57d43..2f304b970c 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -212,7 +212,7 @@ class AppCopyApi(Resource): parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() - data = AppDslService.export_dsl(app_model=app_model) + data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( tenant_id=current_user.current_tenant_id, data=data, @@ -234,8 +234,13 @@ class AppExportApi(Resource): if not current_user.is_editor: raise Forbidden() + # Add include_secret params + parser = reqparse.RequestParser() + parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + 
args = parser.parse_args() + return { - "data": AppDslService.export_dsl(app_model=app_model) + "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) } diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9f745ca120..686ef7b4be 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -13,6 +13,7 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.segments import factory from core.errors.error import AppInvokeQuotaExceededError from fields.workflow_fields import workflow_fields from fields.workflow_run_fields import workflow_run_node_execution_fields @@ -41,7 +42,7 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch draft workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_draft_workflow(app_model=app_model) @@ -64,13 +65,15 @@ class DraftWorkflowApi(Resource): if not current_user.is_editor: raise Forbidden() - content_type = request.headers.get('Content-Type') + content_type = request.headers.get('Content-Type', '') if 'application/json' in content_type: parser = reqparse.RequestParser() parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('features', type=dict, required=True, nullable=False, location='json') parser.add_argument('hash', type=str, required=False, location='json') + # TODO: set this to required=True after frontend is updated + parser.add_argument('environment_variables', type=list, required=False, location='json') args = parser.parse_args() elif 'text/plain' in content_type: try: @@ -84,7 +87,8 
@@ class DraftWorkflowApi(Resource): args = { 'graph': data.get('graph'), 'features': data.get('features'), - 'hash': data.get('hash') + 'hash': data.get('hash'), + 'environment_variables': data.get('environment_variables') } except json.JSONDecodeError: return {'message': 'Invalid JSON data'}, 400 @@ -94,12 +98,15 @@ class DraftWorkflowApi(Resource): workflow_service = WorkflowService() try: + environment_variables_list = args.get('environment_variables') or [] + environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args.get('graph'), - features=args.get('features'), + graph=args['graph'], + features=args['features'], unique_hash=args.get('hash'), - account=current_user + account=current_user, + environment_variables=environment_variables, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70c506bb0e..934b6413ae 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,10 +1,11 @@ import flask_restful -from flask import current_app, request +from flask import request from flask_login import current_user from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound import services +from configs import dify_config from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError @@ -530,7 +531,7 @@ class DatasetApiBaseUrlApi(Resource): @account_initialization_required def get(self): return { - 'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] + 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip('/')) 
+ '/v1' } @@ -540,20 +541,20 @@ class DatasetRetrievalSettingApi(Resource): @login_required @account_initialization_required def get(self): - vector_type = current_app.config['VECTOR_STORE'] + vector_type = dify_config.VECTOR_STORE match vector_type: case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH + RetrievalMethod.SEMANTIC_SEARCH.value ] } case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH, - RetrievalMethod.FULL_TEXT_SEARCH, - RetrievalMethod.HYBRID_SEARCH, + RetrievalMethod.SEMANTIC_SEARCH.value, + RetrievalMethod.FULL_TEXT_SEARCH.value, + RetrievalMethod.HYBRID_SEARCH.value, ] } case _: @@ -569,15 +570,15 @@ class DatasetRetrievalSettingMockApi(Resource): case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH + RetrievalMethod.SEMANTIC_SEARCH.value ] } case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE: return { 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH, - RetrievalMethod.FULL_TEXT_SEARCH, - RetrievalMethod.HYBRID_SEARCH, + RetrievalMethod.SEMANTIC_SEARCH.value, + RetrievalMethod.FULL_TEXT_SEARCH.value, + RetrievalMethod.HYBRID_SEARCH.value, ] } case _: diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a189aac3f1..3dcade6152 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource): ) if last_id is not None: - last_segment = 
DocumentSegment.query.get(str(last_id)) + last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: query = query.filter( DocumentSegment.position > last_segment.position) diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index c13bd45abb..3b2083bcc3 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,8 +1,9 @@ -from flask import current_app, request +from flask import request from flask_login import current_user from flask_restful import Resource, marshal_with import services +from configs import dify_config from controllers.console import api from controllers.console.datasets.error import ( FileTooLargeError, @@ -26,9 +27,9 @@ class FileApi(Resource): @account_initialization_required @marshal_with(upload_config_fields) def get(self): - file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") - batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT") - image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") + file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT + batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT + image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT return { 'file_size_limit': file_size_limit, 'batch_count_limit': batch_count_limit, @@ -76,7 +77,7 @@ class FileSupportTypeApi(Resource): @login_required @account_initialization_required def get(self): - etl_type = current_app.config['ETL_TYPE'] + etl_type = dify_config.ETL_TYPE allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS return {'allowed_extensions': allowed_extensions} diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 920b1d8383..27cc83042a 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -78,10 +78,12 @@ class ChatTextApi(InstalledAppResource): 
parser = reqparse.RequestParser() parser.add_argument('message_id', type=str, required=False, location='json') parser.add_argument('voice', type=str, location='json') + parser.add_argument('text', type=str, location='json') parser.add_argument('streaming', type=bool, location='json') args = parser.parse_args() - message_id = args.get('message_id') + message_id = args.get('message_id', None) + text = args.get('text', None) if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] and app_model.workflow and app_model.workflow.features_dict): @@ -95,7 +97,8 @@ class ChatTextApi(InstalledAppResource): response = AudioService.transcript_tts( app_model=app_model, message_id=message_id, - voice=voice + voice=voice, + text=text ) return response except services.errors.app_model_config.AppModelConfigBrokenError: diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 45255edb3a..0a168d6306 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,7 +1,7 @@ -from flask import current_app from flask_restful import fields, marshal_with +from configs import dify_config from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource @@ -78,7 +78,7 @@ class AppParameterApi(InstalledAppResource): "transfer_methods": ["remote_url", "local_file"] }}), 'system_parameters': { - 'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT') + 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT } } diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index b319f706b4..6feb1003a9 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,8 +1,9 @@ import os -from flask import current_app, session +from flask import session from 
flask_restful import Resource, reqparse +from configs import dify_config from libs.helper import str_len from models.model import DifySetup from services.account_service import TenantService @@ -40,7 +41,7 @@ class InitValidateAPI(Resource): return {'result': 'success'}, 201 def get_init_validate_status(): - if current_app.config['EDITION'] == 'SELF_HOSTED': + if dify_config.EDITION == 'SELF_HOSTED': if os.environ.get('INIT_PASSWORD'): return session.get('is_init_validated') or DifySetup.query.first() diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index def50212a1..ef7cc6bc03 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,8 +1,9 @@ from functools import wraps -from flask import current_app, request +from flask import request from flask_restful import Resource, reqparse +from configs import dify_config from libs.helper import email, get_remote_ip, str_len from libs.password import valid_password from models.model import DifySetup @@ -17,7 +18,7 @@ from .wraps import only_edition_self_hosted class SetupApi(Resource): def get(self): - if current_app.config['EDITION'] == 'SELF_HOSTED': + if dify_config.EDITION == 'SELF_HOSTED': setup_status = get_setup_status() if setup_status: return { @@ -77,7 +78,7 @@ def setup_required(view): def get_setup_status(): - if current_app.config['EDITION'] == 'SELF_HOSTED': + if dify_config.EDITION == 'SELF_HOSTED': return DifySetup.query.first() else: return True diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index faf36c4f40..1fcf4bdc00 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -3,9 +3,10 @@ import json import logging import requests -from flask import current_app from flask_restful import Resource, reqparse +from configs import dify_config + from . 
import api @@ -15,16 +16,16 @@ class VersionApi(Resource): parser = reqparse.RequestParser() parser.add_argument('current_version', type=str, required=True, location='args') args = parser.parse_args() - check_update_url = current_app.config['CHECK_UPDATE_URL'] + check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': current_app.config['CURRENT_VERSION'], + 'version': dify_config.CURRENT_VERSION, 'release_date': '', 'release_notes': '', 'can_auto_update': False, 'features': { - 'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'], - 'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED'] + 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, + 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED } } diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0b5c84c2a3..1056d5eb62 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,10 +1,11 @@ import datetime import pytz -from flask import current_app, request +from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse +from configs import dify_config from constants.languages import supported_language from controllers.console import api from controllers.console.setup import setup_required @@ -36,7 +37,7 @@ class AccountInitApi(Resource): parser = reqparse.RequestParser() - if current_app.config['EDITION'] == 'CLOUD': + if dify_config.EDITION == 'CLOUD': parser.add_argument('invitation_code', type=str, location='json') parser.add_argument( @@ -45,7 +46,7 @@ class AccountInitApi(Resource): required=True, location='json') args = parser.parse_args() - if current_app.config['EDITION'] == 'CLOUD': + if dify_config.EDITION == 'CLOUD': if not args['invitation_code']: raise ValueError('invitation_code is required') diff --git a/api/controllers/console/workspace/members.py 
b/api/controllers/console/workspace/members.py index e8c88850a4..34e9da3841 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,8 +1,8 @@ -from flask import current_app from flask_login import current_user from flask_restful import Resource, abort, marshal_with, reqparse import services +from configs import dify_config from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check @@ -48,7 +48,7 @@ class MemberInviteEmailApi(Resource): inviter = current_user invitation_results = [] - console_web_url = current_app.config.get("CONSOLE_WEB_URL") + console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) @@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 - member = Account.query.get(str(member_id)) + member = db.session.get(Account, str(member_id)) if not member: abort(404) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 6e3f78d4e2..bafeabb08a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,10 +1,11 @@ import io -from flask import current_app, send_file +from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden +from configs import dify_config from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -104,7 +105,7 @@ class 
ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider) - icon_cache_max_age = current_app.config.get('TOOL_ICON_CACHE_MAX_AGE') + icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) class ToolApiProviderAddApi(Resource): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 7c8ad11078..3baf69acfd 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,9 +1,10 @@ import json from functools import wraps -from flask import abort, current_app, request +from flask import abort, request from flask_login import current_user +from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from services.feature_service import FeatureService from services.operation_service import OperationService @@ -26,7 +27,7 @@ def account_initialization_required(view): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if current_app.config['EDITION'] != 'CLOUD': + if dify_config.EDITION != 'CLOUD': abort(404) return view(*args, **kwargs) @@ -37,7 +38,7 @@ def only_edition_cloud(view): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if current_app.config['EDITION'] != 'SELF_HOSTED': + if dify_config.EDITION != 'SELF_HOSTED': abort(404) return view(*args, **kwargs) diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 607d71598f..3c009af343 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -76,10 +76,12 @@ class TextApi(Resource): parser = reqparse.RequestParser() parser.add_argument('message_id', type=str, required=False, location='json') parser.add_argument('voice', type=str, location='json') + 
parser.add_argument('text', type=str, location='json') parser.add_argument('streaming', type=bool, location='json') args = parser.parse_args() - message_id = args.get('message_id') + message_id = args.get('message_id', None) + text = args.get('text', None) if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] and app_model.workflow and app_model.workflow.features_dict): @@ -87,15 +89,15 @@ class TextApi(Resource): voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( - 'voice') + voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') except Exception: voice = None response = AudioService.transcript_tts( app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, - voice=voice + voice=voice, + text=text ) return response diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index dd11949e84..10484c9027 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, fields, marshal_with, reqparse from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -21,14 +21,43 @@ from core.errors.error import ( QuotaExceededError, ) from core.model_runtime.errors.invoke import InvokeError +from extensions.ext_database import db from libs import helper from models.model import App, AppMode, EndUser +from models.workflow import WorkflowRun from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) class WorkflowRunApi(Resource): + workflow_run_fields = { + 'id': fields.String, + 'workflow_id': fields.String, + 'status': fields.String, + 'inputs': fields.Raw, + 
'outputs': fields.Raw, + 'error': fields.String, + 'total_steps': fields.Integer, + 'total_tokens': fields.Integer, + 'created_at': fields.DateTime, + 'finished_at': fields.DateTime, + 'elapsed_time': fields.Float, + } + + @validate_app_token + @marshal_with(workflow_run_fields) + def get(self, app_model: App, workflow_id: str): + """ + Get a workflow task running detail + """ + app_mode = AppMode.value_of(app_model.mode) + if app_mode != AppMode.WORKFLOW: + raise NotWorkflowAppError() + + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() + return workflow_run + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): """ @@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource): } -api.add_resource(WorkflowRunApi, '/workflows/run') +api.add_resource(WorkflowRunApi, '/workflows/run/', '/workflows/run') api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 8be872f5f9..0e905f905a 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -74,10 +74,12 @@ class TextApi(WebApiResource): parser = reqparse.RequestParser() parser.add_argument('message_id', type=str, required=False, location='json') parser.add_argument('voice', type=str, location='json') + parser.add_argument('text', type=str, location='json') parser.add_argument('streaming', type=bool, location='json') args = parser.parse_args() - message_id = args.get('message_id') + message_id = args.get('message_id', None) + text = args.get('text', None) if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] and app_model.workflow and app_model.workflow.features_dict): @@ -94,7 +96,8 @@ class TextApi(WebApiResource): app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, - voice=voice + voice=voice, + text=text ) return response diff --git 
a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index bec76e7a24..7019b5e39f 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -342,10 +342,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): """ tool_calls = [] for prompt_message in llm_result_chunk.delta.message.tool_calls: + args = {} + if prompt_message.function.arguments != '': + args = json.loads(prompt_message.function.arguments) + tool_calls.append(( prompt_message.id, prompt_message.function.name, - json.loads(prompt_message.function.arguments), + args, )) return tool_calls @@ -359,10 +363,14 @@ class FunctionCallAgentRunner(BaseAgentRunner): """ tool_calls = [] for prompt_message in llm_result.message.tool_calls: + args = {} + if prompt_message.function.arguments != '': + args = json.loads(prompt_message.function.arguments) + tool_calls.append(( prompt_message.id, prompt_message.function.name, - json.loads(prompt_message.function.arguments), + args, )) return tool_calls diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 353fe85b74..3dea305e98 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -1,6 +1,7 @@ -from typing import Optional, Union +from collections.abc import Mapping +from typing import Any -from core.app.app_config.entities import AppAdditionalFeatures, EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import AppAdditionalFeatures from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager @@ -10,37 +11,19 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from 
core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import AppMode, AppModelConfig +from models.model import AppMode class BaseAppConfigManager: - @classmethod - def convert_to_config_dict(cls, config_from: EasyUIBasedAppModelConfigFrom, - app_model_config: Union[AppModelConfig, dict], - config_dict: Optional[dict] = None) -> dict: - """ - Convert app model config to config dict - :param config_from: app model config from - :param app_model_config: app model config - :param config_dict: app model config dict - :return: - """ - if config_from != EasyUIBasedAppModelConfigFrom.ARGS: - app_model_config_dict = app_model_config.to_dict() - config_dict = app_model_config_dict.copy() - - return config_dict - - @classmethod - def convert_features(cls, config_dict: dict, app_mode: AppMode) -> AppAdditionalFeatures: + def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> AppAdditionalFeatures: """ Convert app config to app model config :param config_dict: app config :param app_mode: app mode """ - config_dict = config_dict.copy() + config_dict = dict(config_dict.items()) additional_features = AppAdditionalFeatures() additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 2049b573cd..86799fb1ab 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,11 +1,12 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.app.app_config.entities import FileExtraConfig class FileUploadConfigManager: @classmethod - def convert(cls, config: dict, is_vision: bool = True) -> Optional[FileExtraConfig]: + def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: """ Convert model 
config to model config diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index b516fa46ab..f11e268e73 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -3,13 +3,13 @@ from core.app.app_config.entities import TextToSpeechEntity class TextToSpeechConfigManager: @classmethod - def convert(cls, config: dict) -> bool: + def convert(cls, config: dict): """ Convert model config to model config :param config: model config args """ - text_to_speech = False + text_to_speech = None text_to_speech_dict = config.get('text_to_speech') if text_to_speech_dict: if text_to_speech_dict.get('enabled'): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 84723cb5c7..0141dbec58 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import os import threading @@ -8,6 +9,7 @@ from typing import Union from flask import Flask, current_app from pydantic import ValidationError +import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -107,6 +109,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras=extras, trace_manager=trace_manager ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( app_model=app_model, @@ -173,6 +176,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): inputs=args['inputs'] ) ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( app_model=app_model, @@ -225,6 +229,8 @@ class 
AdvancedChatAppGenerator(MessageBasedAppGenerator): 'queue_manager': queue_manager, 'conversation_id': conversation.id, 'message_id': message.id, + 'user': user, + 'context': contextvars.copy_context() }) worker_thread.start() @@ -249,7 +255,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, - message_id: str) -> None: + message_id: str, + user: Account, + context: contextvars.Context) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -259,6 +267,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param message_id: message ID :return: """ + for var, val in context.items(): + var.set(val) with flask_app.app_context(): try: runner = AdvancedChatAppRunner() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d633b30029..208736b990 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,7 +1,8 @@ import logging import os import time -from typing import Optional, cast +from collections.abc import Mapping +from typing import Any, Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback @@ -14,6 +15,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -86,7 +88,7 @@ class AdvancedChatAppRunner(AppRunner): db.session.close() - 
workflow_callbacks = [WorkflowEventTriggerCallback( + workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( queue_manager=queue_manager, workflow=workflow )] @@ -160,7 +162,7 @@ class AdvancedChatAppRunner(AppRunner): self, queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str ) -> bool: diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 08069332ba..ef579827b4 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -1,9 +1,11 @@ import json from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( + AppBlockingResponse, + AppStreamResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, @@ -18,12 +20,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: """ Convert blocking full response. 
:param blocking_response: blocking response :return: """ + blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { 'event': 'message', 'task_id': blocking_response.task_id, @@ -39,7 +42,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -53,8 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -83,8 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: """ Convert stream simple response. 
:param stream_response: stream response diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b332ac7af8..e5451ffb3b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None - def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def process(self): """ Process generate task pipeline. :return: @@ -136,8 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> ChatbotAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse: """ Process blocking response. :return: @@ -167,8 +166,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc raise Exception('Queue listening stopped unexpectedly.') - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[ChatbotAppStreamResponse, None, None]: + def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. 
:return: diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index 78fe077e6b..8d43155a08 100644 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -14,13 +14,13 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from models.workflow import Workflow -class WorkflowEventTriggerCallback(BaseWorkflowCallback): +class WorkflowEventTriggerCallback(WorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index bacd1a5477..1165314a7f 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator -from typing import Union +from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse @@ -15,44 +15,41 @@ class AppGenerateResponseConverter(ABC): @classmethod def convert(cls, response: Union[ AppBlockingResponse, - Generator[AppStreamResponse, None, None] - ], invoke_from: InvokeFrom) -> Union[ - dict, - Generator[str, None, None] - ]: + Generator[AppStreamResponse, Any, None] + ], invoke_from: InvokeFrom): if invoke_from in [InvokeFrom.DEBUGGER, 
InvokeFrom.SERVICE_API]: - if isinstance(response, cls._blocking_response_type): + if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: - def _generate(): + def _generate_full_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_full_response(response): if chunk == 'ping': yield f'event: {chunk}\n\n' else: yield f'data: {chunk}\n\n' - return _generate() + return _generate_full_response() else: - if isinstance(response, cls._blocking_response_type): + if isinstance(response, AppBlockingResponse): return cls.convert_blocking_simple_response(response) else: - def _generate(): + def _generate_simple_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_simple_response(response): if chunk == 'ping': yield f'event: {chunk}\n\n' else: yield f'data: {chunk}\n\n' - return _generate() + return _generate_simple_response() @classmethod @abstractmethod - def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: raise NotImplementedError @classmethod @abstractmethod - def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: raise NotImplementedError @classmethod @@ -68,7 +65,7 @@ class AppGenerateResponseConverter(ABC): raise NotImplementedError @classmethod - def _get_simple_metadata(cls, metadata: dict) -> dict: + def _get_simple_metadata(cls, metadata: dict[str, Any]): """ Get simple metadata. 
:param metadata: metadata diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 0f547ca164..b1986dbcee 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import os import threading @@ -8,6 +9,7 @@ from typing import Union from flask import Flask, current_app from pydantic import ValidationError +import contexts from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom @@ -38,7 +40,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, - ) -> Union[dict, Generator[dict, None, None]]: + ): """ Generate App response. @@ -86,6 +88,7 @@ class WorkflowAppGenerator(BaseAppGenerator): call_depth=call_depth, trace_manager=trace_manager ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( app_model=app_model, @@ -126,7 +129,8 @@ class WorkflowAppGenerator(BaseAppGenerator): worker_thread = threading.Thread(target=self._generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager + 'queue_manager': queue_manager, + 'context': contextvars.copy_context() }) worker_thread.start() @@ -150,8 +154,7 @@ class WorkflowAppGenerator(BaseAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): """ Generate App response. 
@@ -193,6 +196,7 @@ class WorkflowAppGenerator(BaseAppGenerator): inputs=args['inputs'] ) ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) return self._generate( app_model=app_model, @@ -205,7 +209,8 @@ class WorkflowAppGenerator(BaseAppGenerator): def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager) -> None: + queue_manager: AppQueueManager, + context: contextvars.Context) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -213,6 +218,8 @@ class WorkflowAppGenerator(BaseAppGenerator): :param queue_manager: queue manager :return: """ + for var, val in context.items(): + var.set(val) with flask_app.app_context(): try: # workflow app diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b3cd517fc6..4cb027fa0a 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -56,7 +57,7 @@ class WorkflowAppRunner: db.session.close() - workflow_callbacks = [WorkflowEventTriggerCallback( + workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( queue_manager=queue_manager, workflow=workflow )] diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py index e423a40bcb..4472a7e9b5 100644 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -14,13 +14,13 @@ from core.app.entities.queue_entities import ( 
QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType from models.workflow import Workflow -class WorkflowEventTriggerCallback(BaseWorkflowCallback): +class WorkflowEventTriggerCallback(WorkflowCallback): def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index f617c671e9..2e6431d6d0 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -2,7 +2,7 @@ from typing import Optional from core.app.entities.queue_entities import AppQueueEvent from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType @@ -15,7 +15,7 @@ _TEXT_COLOR_MAPPING = { } -class WorkflowLoggingCallback(BaseWorkflowCallback): +class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: self.current_node_id = None diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 1d2ad4a373..9a861c29e2 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from enum import Enum from typing import Any, Optional @@ -76,7 +77,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: AppConfig - inputs: dict[str, Any] + inputs: 
Mapping[str, Any] files: list[FileVar] = [] user_id: str @@ -140,7 +141,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): app_config: WorkflowUIBasedAppConfig conversation_id: Optional[str] = None - query: Optional[str] = None + query: str class SingleIterationRunEntity(BaseModel): """ diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py new file mode 100644 index 0000000000..e5cecd35fd --- /dev/null +++ b/api/core/app/segments/__init__.py @@ -0,0 +1,27 @@ +from .segment_group import SegmentGroup +from .segments import Segment +from .types import SegmentType +from .variables import ( + ArrayVariable, + FileVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, +) + +__all__ = [ + 'IntegerVariable', + 'FloatVariable', + 'ObjectVariable', + 'SecretVariable', + 'FileVariable', + 'StringVariable', + 'ArrayVariable', + 'Variable', + 'SegmentType', + 'SegmentGroup', + 'Segment' +] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py new file mode 100644 index 0000000000..4f0b361d95 --- /dev/null +++ b/api/core/app/segments/factory.py @@ -0,0 +1,64 @@ +from collections.abc import Mapping +from typing import Any + +from core.file.file_obj import FileVar + +from .segments import Segment, StringSegment +from .types import SegmentType +from .variables import ( + ArrayVariable, + FileVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, +) + + +def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: + if (value_type := m.get('value_type')) is None: + raise ValueError('missing value type') + if not m.get('name'): + raise ValueError('missing name') + if (value := m.get('value')) is None: + raise ValueError('missing value') + match value_type: + case SegmentType.STRING: + return StringVariable.model_validate(m) + case SegmentType.NUMBER if isinstance(value, int): + return 
IntegerVariable.model_validate(m) + case SegmentType.NUMBER if isinstance(value, float): + return FloatVariable.model_validate(m) + case SegmentType.SECRET: + return SecretVariable.model_validate(m) + case SegmentType.NUMBER if not isinstance(value, float | int): + raise ValueError(f'invalid number value {value}') + raise ValueError(f'not supported value type {value_type}') + + +def build_anonymous_variable(value: Any, /) -> Variable: + if isinstance(value, str): + return StringVariable(name='anonymous', value=value) + if isinstance(value, int): + return IntegerVariable(name='anonymous', value=value) + if isinstance(value, float): + return FloatVariable(name='anonymous', value=value) + if isinstance(value, dict): + # TODO: Limit the depth of the object + obj = {k: build_anonymous_variable(v) for k, v in value.items()} + return ObjectVariable(name='anonymous', value=obj) + if isinstance(value, list): + # TODO: Limit the depth of the array + elements = [build_anonymous_variable(v) for v in value] + return ArrayVariable(name='anonymous', value=elements) + if isinstance(value, FileVar): + return FileVariable(name='anonymous', value=value) + raise ValueError(f'not supported value {value}') + + +def build_segment(value: Any, /) -> Segment: + if isinstance(value, str): + return StringSegment(value=value) + raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py new file mode 100644 index 0000000000..21d1b89541 --- /dev/null +++ b/api/core/app/segments/parser.py @@ -0,0 +1,17 @@ +import re + +from core.app.segments import SegmentGroup, factory +from core.workflow.entities.variable_pool import VariablePool + +VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') + + +def convert_template(*, template: str, variable_pool: VariablePool): + parts = re.split(VARIABLE_PATTERN, template) + segments = [] + for part in parts: + if '.' 
in part and (value := variable_pool.get(part.split('.'))): + segments.append(value) + else: + segments.append(factory.build_segment(part)) + return SegmentGroup(segments=segments) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py new file mode 100644 index 0000000000..0d5176b885 --- /dev/null +++ b/api/core/app/segments/segment_group.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +from .segments import Segment + + +class SegmentGroup(BaseModel): + segments: list[Segment] + + @property + def text(self): + return ''.join([segment.text for segment in self.segments]) + + @property + def log(self): + return ''.join([segment.log for segment in self.segments]) + + @property + def markdown(self): + return ''.join([segment.markdown for segment in self.segments]) \ No newline at end of file diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py new file mode 100644 index 0000000000..a6e953829e --- /dev/null +++ b/api/core/app/segments/segments.py @@ -0,0 +1,39 @@ +from typing import Any + +from pydantic import BaseModel, ConfigDict, field_validator + +from .types import SegmentType + + +class Segment(BaseModel): + model_config = ConfigDict(frozen=True) + + value_type: SegmentType + value: Any + + @field_validator('value_type') + def validate_value_type(cls, value): + """ + This validator checks if the provided value is equal to the default value of the 'value_type' field. + If the value is different, a ValueError is raised. 
+ """ + if value != cls.model_fields['value_type'].default: + raise ValueError("Cannot modify 'value_type'") + return value + + @property + def text(self) -> str: + return str(self.value) + + @property + def log(self) -> str: + return str(self.value) + + @property + def markdown(self) -> str: + return str(self.value) + + +class StringSegment(Segment): + value_type: SegmentType = SegmentType.STRING + value: str diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py new file mode 100644 index 0000000000..517f210533 --- /dev/null +++ b/api/core/app/segments/types.py @@ -0,0 +1,17 @@ +from enum import Enum + + +class SegmentType(str, Enum): + STRING = 'string' + NUMBER = 'number' + FILE = 'file' + + SECRET = 'secret' + + OBJECT = 'object' + + ARRAY = 'array' + ARRAY_STRING = 'array[string]' + ARRAY_NUMBER = 'array[number]' + ARRAY_OBJECT = 'array[object]' + ARRAY_FILE = 'array[file]' \ No newline at end of file diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py new file mode 100644 index 0000000000..e600b442d6 --- /dev/null +++ b/api/core/app/segments/variables.py @@ -0,0 +1,83 @@ +import json +from collections.abc import Mapping, Sequence + +from pydantic import Field + +from core.file.file_obj import FileVar +from core.helper import encrypter + +from .segments import Segment, StringSegment +from .types import SegmentType + + +class Variable(Segment): + """ + A variable is a segment that has a name. + """ + + id: str = Field( + default='', + description="Unique identity for variable. 
It's only used by environment variables now.", + ) + name: str + + +class StringVariable(StringSegment, Variable): + pass + + +class FloatVariable(Variable): + value_type: SegmentType = SegmentType.NUMBER + value: float + + +class IntegerVariable(Variable): + value_type: SegmentType = SegmentType.NUMBER + value: int + + +class ObjectVariable(Variable): + value_type: SegmentType = SegmentType.OBJECT + value: Mapping[str, Variable] + + @property + def text(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False) + + @property + def log(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + @property + def markdown(self) -> str: + # TODO: Use markdown code block + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + +class ArrayVariable(Variable): + value_type: SegmentType = SegmentType.ARRAY + value: Sequence[Variable] + + @property + def markdown(self) -> str: + return '\n'.join(['- ' + item.markdown for item in self.value]) + + +class FileVariable(Variable): + value_type: SegmentType = SegmentType.FILE + # TODO: embed FileVar in this model. 
+ value: FileVar + + @property + def markdown(self) -> str: + return self.value.to_markdown() + + +class SecretVariable(StringVariable): + value_type: SegmentType = SegmentType.SECRET + + @property + def log(self) -> str: + return encrypter.obfuscated_token(self.value) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index f973b7e1ce..03f8244bab 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,9 +1,11 @@ import os +from collections.abc import Mapping, Sequence from typing import Any, Optional, TextIO, Union from pydantic import BaseModel from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName +from core.tools.entities.tool_entities import ToolInvokeMessage _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -43,7 +45,7 @@ class DifyAgentCallbackHandler(BaseModel): def on_tool_start( self, tool_name: str, - tool_inputs: dict[str, Any], + tool_inputs: Mapping[str, Any], ) -> None: """Do nothing.""" print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) @@ -51,8 +53,8 @@ class DifyAgentCallbackHandler(BaseModel): def on_tool_end( self, tool_name: str, - tool_inputs: dict[str, Any], - tool_outputs: str, + tool_inputs: Mapping[str, Any], + tool_outputs: Sequence[ToolInvokeMessage], message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 842b539ad1..7b2f8217f9 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,4 +1,5 @@ -from typing import Union +from collections.abc import Mapping, Sequence +from typing import Any, Union import requests @@ -16,7 +17,7 @@ class MessageFileParser: self.tenant_id = tenant_id self.app_id = app_id - def 
validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig, + def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]) -> list[FileVar]: """ validate and transform files arg diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index f094f7d79b..5b69d3af4b 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY -CODE_EXECUTION_TIMEOUT= (10, 60) +CODE_EXECUTION_TIMEOUT = (10, 60) class CodeExecutionException(Exception): pass @@ -64,7 +64,7 @@ class CodeExecutor: @classmethod def execute_code(cls, - language: Literal['python3', 'javascript', 'jinja2'], + language: CodeLanguage, preload: str, code: str, dependencies: Optional[list[CodeDependency]] = None) -> str: @@ -119,7 +119,7 @@ class CodeExecutor: return response.data.stdout @classmethod - def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: """ Execute code :param language: code language diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index fcf293dc1c..bf87a842c0 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -6,11 +6,16 @@ from models.account import Tenant def obfuscated_token(token: str): - return token[:6] + '*' * (len(token) - 8) + token[-2:] + if not token: + return token + if len(token) <= 8: + return '*' * 20 + return token[:6] + '*' * 12 + token[-2:] def 
encrypt_token(tenant_id: str, token: str): - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): + raise ValueError(f'Tenant with id {tenant_id} not found') encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index e4ceeb652e..04675d85bb 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -14,6 +14,9 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> :return: a dict with name as key and index as value """ position_file_name = os.path.join(folder_path, file_name) + if not position_file_name or not os.path.exists(position_file_name): + return {} + positions = load_yaml_file(position_file_name, ignore_error=True) position_map = {} index = 0 diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 170a28432b..a1737f00c6 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -64,6 +64,7 @@ User Input: SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keeping each question under 20 characters.\n" + "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" "[\"question1\",\"question2\",\"question3\"]\n" ) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 21f1965e93..b33d4dd7cb 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -103,7 +103,7 @@ class TokenBufferMemory: if curr_message_tokens > 
max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and prompt_messages: + while curr_message_tokens > max_token_limit and len(prompt_messages)>1: pruned_memory.append(prompt_messages.pop(0)) curr_message_tokens = self.model_instance.get_llm_num_tokens( prompt_messages diff --git a/api/core/model_manager.py b/api/core/model_manager.py index d64db890f9..dc7556f09a 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -413,6 +413,7 @@ class LBModelManager: for load_balancing_config in self._load_balancing_configs: if load_balancing_config.name == "__inherit__": if not managed_credentials: + # FIXME: Mutation to loop iterable `self._load_balancing_configs` during iteration # remove __inherit__ if managed credentials is not provided self._load_balancing_configs.remove(load_balancing_config) else: diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml index 72d4d8545b..e02c5517fe 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml @@ -27,9 +27,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens required: true - default: 4096 + default: 8192 min: 1 - max: 4096 + max: 8192 - name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index fbc0b722b1..107efe4867 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if system: extra_model_kwargs['system'] = system + # Add the new header for claude-3-5-sonnet-20240620 model + extra_headers = {} + if model == 
"claude-3-5-sonnet-20240620": + extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" + if tools: extra_model_kwargs['tools'] = [ self._transform_tool_prompt(tool) for tool in tools @@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): model=model, messages=prompt_message_dicts, stream=stream, + extra_headers=extra_headers, **model_parameters, **extra_model_kwargs ) @@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): model=model, messages=prompt_message_dicts, stream=stream, + extra_headers=extra_headers, **model_parameters, **extra_model_kwargs ) @@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 25bc94cde6..34d1f64210 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -501,7 +501,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} elif isinstance(message, AssistantPromptMessage): - message = cast(AssistantPromptMessage, message) + # message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] diff --git 
a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 566055e3f7..91b9215829 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,6 +1,8 @@ - gpt-4 - gpt-4o - gpt-4o-2024-05-13 +- gpt-4o-mini +- gpt-4o-mini-2024-07-18 - gpt-4-turbo - gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml new file mode 100644 index 0000000000..6f23e0647d --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-mini-2024-07-18 +label: + zh_Hans: gpt-4o-mini-2024-07-18 + en_US: gpt-4o-mini-2024-07-18 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '0.15' + output: '0.60' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml new file mode 100644 index 0000000000..b97fbf8aab --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -0,0 +1,44 @@ +model: 
gpt-4o-mini +label: + zh_Hans: gpt-4o-mini + en_US: gpt-4o-mini +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '0.15' + output: '0.60' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml b/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml index 51131249e5..fd4ed1109d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml @@ -1,4 +1,5 @@ - openai/gpt-4o +- openai/gpt-4o-mini - openai/gpt-4 - openai/gpt-4-32k - openai/gpt-3.5-turbo diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml new file mode 100644 index 0000000000..de0bad4136 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml @@ -0,0 +1,43 @@ +model: openai/gpt-4o-mini +label: + en_US: gpt-4o-mini +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: 
frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: "0.15" + output: "0.60" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/sagemaker/__init__.py b/api/core/model_runtime/model_providers/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png new file mode 100644 index 0000000000..0abe07a78f Binary files /dev/null and b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png new file mode 100644 index 0000000000..6b88942a5c Binary files /dev/null and b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py b/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py new file mode 100644 index 0000000000..f8e7757a96 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -0,0 +1,238 @@ +import json +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +import boto3 + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import 
class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for a large language model hosted on an AWS SageMaker
    inference endpoint.
    """
    # Lazily created boto3 "sagemaker-runtime" client, cached on the instance.
    sagemaker_client: Any = None

    def _get_sagemaker_client(self, credentials: dict) -> Any:
        """Build (once) and return the boto3 SageMaker runtime client.

        Explicit access/secret keys in the credentials take precedence;
        otherwise boto3 resolves credentials from the running environment.
        """
        if self.sagemaker_client:
            return self.sagemaker_client

        access_key = credentials.get('access_key')
        secret_key = credentials.get('secret_key')
        aws_region = credentials.get('aws_region')
        if aws_region:
            if access_key and secret_key:
                self.sagemaker_client = boto3.client(
                    "sagemaker-runtime",
                    aws_access_key_id=access_key,
                    aws_secret_access_key=secret_key,
                    region_name=aws_region)
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
        else:
            self.sagemaker_client = boto3.client("sagemaker-runtime")
        return self.sagemaker_client

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        sm_client = self._get_sagemaker_client(credentials)
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')

        # NOTE(review): only the first prompt message is sent to the
        # endpoint; multi-message conversations are not forwarded — confirm
        # against the endpoint contract.
        response_model = sm_client.invoke_endpoint(
            EndpointName=sagemaker_endpoint,
            Body=json.dumps(
                {
                    "inputs": prompt_messages[0].content,
                    # FIX: normalize None to [] so the payload stays valid
                    # JSON when no stop words are given.
                    "parameters": {"stop": stop or []},
                    "history": []
                }
            ),
            ContentType="application/json",
        )

        assistant_text = response_model['Body'].read().decode('utf8')
        assistant_prompt_message = AssistantPromptMessage(content=assistant_text)

        # Self-hosted endpoint does not report token usage; record zeros.
        usage = self._calc_response_usage(model, credentials, 0, 0)

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Token counting is not supported for arbitrary SageMaker models.

        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials.

        Only checks that the configured completion mode is resolvable; no
        network probe is made against the endpoint.

        :raises CredentialsValidateFailedError: if the model mode is invalid
        """
        try:
            self.get_model_mode(model)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Build the AIModelEntity schema for this user-configured SageMaker model.
        """
        # FIX: these names were used but never imported at module level,
        # causing a NameError at runtime; import them locally.
        from core.model_runtime.entities.model_entities import (
            ModelFeature,
            ModelPropertyKey,
            ParameterRule,
            ParameterType,
        )

        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=credentials.get('context_length', 2048),
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        # Raises KeyError if "mode" is missing — surfaced to the caller as
        # an invalid-configuration error via the invoke error mapping.
        completion_type = LLMMode.value_of(credentials["mode"])

        features = []
        if credentials.get('support_function_call', False):
            features.append(ModelFeature.TOOL_CALL)
        if credentials.get('support_vision', False):
            features.append(ModelFeature.VISION)

        context_length = credentials.get('context_length', 2048)

        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: completion_type,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )
+ """ + sagemaker_client: Any = None + + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): + inputs = [query_input]*len(docs) + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=rerank_endpoint, + Body=json.dumps( + { + "inputs": inputs, + "docs": docs + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + scores = json_obj['scores'] + return scores if isinstance(scores, list) else [scores] + + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + line = 0 + try: + if len(docs) == 0: + return RerankResult( + model=model, + docs=docs + ) + + line = 1 + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 2 + + sagemaker_endpoint = credentials.get('sagemaker_endpoint') + candidate_docs = [] + + scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) + for idx in range(len(scores)): + candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + + sorted(candidate_docs, key=lambda x: 
x['score'], reverse=True) + + line = 3 + rerank_documents = [] + for idx, result in enumerate(candidate_docs): + rerank_document = RerankDocument( + index=idx, + text=result.get('content'), + score=result.get('score', -100.0) + ) + + if score_threshold is not None: + if rerank_document.score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult( + model=model, + docs=rerank_documents + ) + + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py new file mode 100644 index 0000000000..02d05f406c --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class SageMakerProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + pass diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml new file mode 100644 index 0000000000..290cb0edab --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -0,0 +1,125 @@ +provider: sagemaker +label: + zh_Hans: Sagemaker + en_US: Sagemaker +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Customized model on Sagemaker + zh_Hans: Sagemaker上的私有化部署的模型 +background: "#ECE9E3" +help: + title: + en_US: How to deploy customized model on Sagemaker + zh_Hans: 如何在Sagemaker上的私有化部署的模型 + url: + en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint + zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9 +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: sagemaker_endpoint + label: + en_US: sagemaker endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输出你的Sagemaker推理端点 + en_US: Enter your Sagemaker Inference endpoint + - variable: aws_access_key_id + required: false + label: + en_US: Access Key (If not provided, credentials are obtained from the running environment.) 
+ zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。) + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 在此输入您的 Access Key + - variable: aws_secret_access_key + required: false + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 在此输入您的 Secret Access Key + - variable: aws_region + required: false + label: + en_US: AWS Region + zh_Hans: AWS 地区 + type: select + default: us-east-1 + options: + - value: us-east-1 + label: + en_US: US East (N. Virginia) + zh_Hans: 美国东部 (弗吉尼亚北部) + - value: us-west-2 + label: + en_US: US West (Oregon) + zh_Hans: 美国西部 (俄勒冈州) + - value: ap-southeast-1 + label: + en_US: Asia Pacific (Singapore) + zh_Hans: 亚太地区 (新加坡) + - value: ap-northeast-1 + label: + en_US: Asia Pacific (Tokyo) + zh_Hans: 亚太地区 (东京) + - value: eu-central-1 + label: + en_US: Europe (Frankfurt) + zh_Hans: 欧洲 (法兰克福) + - value: us-gov-west-1 + label: + en_US: AWS GovCloud (US-West) + zh_Hans: AWS GovCloud (US-West) + - value: ap-southeast-2 + label: + en_US: Asia Pacific (Sydney) + zh_Hans: 亚太地区 (悉尼) + - value: cn-north-1 + label: + en_US: AWS Beijing (cn-north-1) + zh_Hans: 中国北京 (cn-north-1) + - value: cn-northwest-1 + label: + en_US: AWS Ningxia (cn-northwest-1) + zh_Hans: 中国宁夏 (cn-northwest-1) diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py new file mode 100644 index 0000000000..4b2858b1a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -0,0 +1,214 @@ +import itertools +import json +import logging +import time +from typing import Any, Optional + +import boto3 + +from 
def batch_generator(generator, batch_size):
    """Yield successive lists of at most ``batch_size`` items from *generator*.

    Consumes the iterator lazily; the final batch may be shorter. Yields
    nothing for an exhausted iterator (or a non-positive batch size).
    """
    while batch := list(itertools.islice(generator, batch_size)):
        yield batch
class SageMakerEmbeddingModel(TextEmbeddingModel):
    """
    Model class for a text-embedding model hosted on an AWS SageMaker endpoint.
    """
    # Lazily created boto3 "sagemaker-runtime" client, cached on the instance.
    sagemaker_client: Any = None

    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
        """Call the endpoint once and return the embeddings for *content_list*."""
        response_model = sm_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(
                {
                    "inputs": content_list,
                    "parameters": {},
                    "is_query": False,
                    "instruction": ''
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj['embeddings']

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        try:
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client(
                            "sagemaker-runtime",
                            aws_access_key_id=access_key,
                            aws_secret_access_key=secret_key,
                            region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            sagemaker_endpoint = credentials.get('sagemaker_endpoint')

            # Clamp each text to the model context window before batching.
            truncated_texts = [item[:CONTEXT_SIZE] for item in texts]

            all_embeddings = []
            batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
            for batch in batches:
                embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
                all_embeddings.extend(embeddings)

            usage = self._calc_response_usage(
                model=model,
                credentials=credentials,
                tokens=0  # It's not SAAS API, usage is meaningless
            )

            return TextEmbeddingResult(
                embeddings=all_embeddings,
                usage=usage,
                model=model
            )
        except Exception:
            # FIX: the original swallowed the exception and fell through to
            # return None, violating the declared TextEmbeddingResult return
            # type; log with traceback and propagate instead.
            logger.exception(f'Failed to invoke embedding model {model}')
            raise

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """Token counting is not supported for arbitrary SageMaker models.

        :return: always 0
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials.

        NOTE(review): no real validation is performed yet (the original only
        printed a debug message); consider issuing a probe embedding request.
        """
        pass

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # assumes self.started_at is set by the framework before the
            # invocation starts — TODO confirm against the base class
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """Map SageMaker invocation errors onto the unified error hierarchy."""
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Schema for a user-configured SageMaker embedding model.
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
            },
            parameter_rules=[]
        )
import json
from collections.abc import Generator
from typing import Optional, Union, cast

import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """LLM implementation for the Stepfun (阶跃星辰) provider.

    Stepfun exposes an OpenAI-compatible chat-completion API, so the heavy
    lifting is delegated to ``OAIAPICompatLargeLanguageModel``; this class only
    injects provider-specific defaults (endpoint URL, chat mode, function
    calling) and Stepfun-flavoured message conversion / stream handling.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """Invoke the model through the OpenAI-compatible base implementation.

        :param model: model name (e.g. ``step-1-8k``)
        :param credentials: provider credentials; mutated in place with the
            Stepfun endpoint and function-calling flags before delegation
        :param prompt_messages: conversation history
        :param model_parameters: sampling parameters (temperature, top_p, ...)
        :param tools: tool definitions for function calling
        :param stop: stop sequences
        :param stream: whether to stream the response
        :param user: end-user identifier forwarded to the API
        :return: full result, or a chunk generator when streaming
        """
        self._add_custom_parameters(credentials)
        self._add_function_call(model, credentials)
        # Stepfun restricts the `user` identifier length; truncate to 32 chars.
        user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials by issuing a probe request via the base class."""
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """Build the model schema for user-configured (customizable) models.

        Tool-call features are only advertised when the user explicitly enabled
        ``function_calling_type == 'tool_call'`` in the credentials form.
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model, zh_Hans=model),
            model_type=ModelType.LLM,
            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
            if credentials.get('function_calling_type') == 'tool_call'
            else [],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
            },
            parameter_rules=[
                ParameterRule(
                    name='temperature',
                    use_template='temperature',
                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
                    type=ParameterType.FLOAT,
                ),
                ParameterRule(
                    name='max_tokens',
                    use_template='max_tokens',
                    default=512,
                    min=1,
                    max=int(credentials.get('max_tokens', 1024)),
                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
                    type=ParameterType.INT,
                ),
                ParameterRule(
                    name='top_p',
                    use_template='top_p',
                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
                    type=ParameterType.FLOAT,
                ),
            ]
        )

    def _add_custom_parameters(self, credentials: dict) -> None:
        """Force chat mode and the official Stepfun endpoint into credentials."""
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.stepfun.com/v1'

    def _add_function_call(self, model: str, credentials: dict) -> None:
        """Enable tool-call handling when the model schema advertises it."""
        model_schema = self.get_model_schema(model, credentials)
        if model_schema and {
            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
        }.intersection(model_schema.features or []):
            credentials['function_calling_type'] = 'tool_call'

    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
        """
        Convert PromptMessage to dict for OpenAI API format

        Vision content (text + image parts) is mapped to the OpenAI
        multi-part ``content`` list; tool calls and tool results keep the
        OpenAI ``tool_calls`` / ``tool`` role shape.
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(PromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                            }
                        }
                        sub_messages.append(sub_message_dict)
                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = []
                for function_call in message.tool_calls:
                    message_dict["tool_calls"].append({
                        "id": function_call.id,
                        "type": function_call.type,
                        "function": {
                            "name": function_call.function.name,
                            "arguments": function_call.function.arguments
                        }
                    })
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                # Missing fields default to "" so streamed partial deltas merge cleanly.
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
                    function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: model credentials
        :param response: streamed response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        full_assistant_content = ''
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
            # calculate num tokens
            # NOTE(review): assumes prompt_messages[0].content is plain text;
            # a vision (list) content would break this estimate — confirm upstream.
            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        tools_calls: list[AssistantPromptMessage.ToolCall] = []
        finish_reason = "Unknown"

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            # Merge incremental tool-call deltas into the accumulated list:
            # a delta without a name extends the most recent tool call.
            def get_tool_call(tool_name: str):
                if not tool_name:
                    return tools_calls[-1]

                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id='',
                        type='',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # get tool call
                tool_call = get_tool_call(new_tool_call.function.name)
                # update tool call
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
            if chunk:
                # ignore sse comments
                if chunk.startswith(':'):
                    continue
                # FIX: `lstrip('data: ')` strips any of the characters
                # 'd','a','t',':',' ' and can eat payload bytes; removeprefix
                # strips the literal SSE prefix only.
                decoded_chunk = chunk.strip().removeprefix('data: ').lstrip()
                chunk_json = None
                try:
                    chunk_json = json.loads(decoded_chunk)
                # stream ended
                except json.JSONDecodeError as e:
                    yield create_final_llm_result_chunk(
                        index=chunk_index + 1,
                        message=AssistantPromptMessage(content=""),
                        finish_reason="Non-JSON encountered."
                    )
                    break
                if not chunk_json or len(chunk_json['choices']) == 0:
                    continue

                choice = chunk_json['choices'][0]
                finish_reason = chunk_json['choices'][0].get('finish_reason')
                chunk_index += 1

                if 'delta' in choice:
                    delta = choice['delta']
                    delta_content = delta.get('content')

                    assistant_message_tool_calls = delta.get('tool_calls', None)

                    # extract tool calls from response
                    if assistant_message_tool_calls:
                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
                        increase_tool_call(tool_calls)

                    if delta_content is None or delta_content == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_content,
                        tool_calls=tool_calls if assistant_message_tool_calls else []
                    )

                    full_assistant_content += delta_content
                elif 'text' in choice:
                    choice_text = choice.get('text', '')
                    if choice_text == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
                    full_assistant_content += choice_text
                else:
                    continue

                # check payload indicator for completion
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

        chunk_index += 1

        # flush the accumulated tool calls as a single trailing chunk
        if tools_calls:
            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=chunk_index,
                    message=AssistantPromptMessage(
                        tool_calls=tools_calls,
                        content=""
                    ),
                )
            )

        yield create_final_llm_result_chunk(
            index=chunk_index,
            message=AssistantPromptMessage(content=""),
            finish_reason=finish_reason
        )
100644 index 0000000000..13f7b7fd26 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml @@ -0,0 +1,25 @@ +model: step-1-128k +label: + zh_Hans: step-1-128k + en_US: step-1-128k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 128000 +pricing: + input: '0.04' + output: '0.20' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml new file mode 100644 index 0000000000..f80ec9851c --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml @@ -0,0 +1,25 @@ +model: step-1-256k +label: + zh_Hans: step-1-256k + en_US: step-1-256k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 256000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 256000 +pricing: + input: '0.095' + output: '0.300' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml new file mode 100644 index 0000000000..96132d14a8 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml @@ -0,0 +1,28 @@ +model: step-1-32k +label: + zh_Hans: step-1-32k + en_US: step-1-32k +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + 
use_template: max_tokens + default: 1024 + min: 1 + max: 32000 +pricing: + input: '0.015' + output: '0.070' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml new file mode 100644 index 0000000000..4a4ba8d178 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml @@ -0,0 +1,28 @@ +model: step-1-8k +label: + zh_Hans: step-1-8k + en_US: step-1-8k +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8000 +pricing: + input: '0.005' + output: '0.020' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml new file mode 100644 index 0000000000..f878ee3e56 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml @@ -0,0 +1,25 @@ +model: step-1v-32k +label: + zh_Hans: step-1v-32k + en_US: step-1v-32k +model_type: llm +features: + - vision +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 +pricing: + input: '0.015' + output: '0.070' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml new file mode 100644 index 0000000000..6c3cb61d2c --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml @@ -0,0 +1,25 @@ +model: step-1v-8k +label: + 
zh_Hans: step-1v-8k + en_US: step-1v-8k +model_type: llm +features: + - vision +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.005' + output: '0.020' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py new file mode 100644 index 0000000000..50b17392b5 --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -0,0 +1,30 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class StepfunProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + model_instance.validate_credentials( + model='step-1-8k', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.yaml b/api/core/model_runtime/model_providers/stepfun/stepfun.yaml new file mode 100644 index 0000000000..ccc8455adc --- /dev/null +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.yaml @@ -0,0 +1,81 @@ +provider: stepfun +label: + zh_Hans: 阶跃星辰 + en_US: Stepfun +description: + en_US: Models provided by stepfun, such as step-1-8k, step-1-32k、step-1v-8k、step-1v-32k, step-1-128k and step-1-256k + zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。 +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +background: "#FFFFFF" +help: + title: + en_US: Get your API Key from stepfun + zh_Hans: 从 stepfun 获取 API Key + url: + en_US: https://platform.stepfun.com/interface-key +supported_model_types: + - llm +configurate_methods: + - predefined-model + - customizable-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '8192' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model 
context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '8192' + type: text-input + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not supported + zh_Hans: 不支持 + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-0205.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-0205.yaml index 34f73dccbb..b308abcb32 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-0205.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-3.5-8k-0205.yaml @@ -35,3 +35,4 @@ parameter_rules: zh_Hans: 禁用模型自行进行外部搜索。 en_US: Disable the model to perform external search. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml index 50c82564f1..d23ae0dc48 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-8k-latest.yaml @@ -1,4 +1,4 @@ -model: ernie-4.0-8k-Latest +model: ernie-4.0-8k-latest label: en_US: Ernie-4.0-8K-Latest model_type: llm diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k-preview b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k-preview.yaml similarity index 100% rename from api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k-preview rename to api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k-preview.yaml diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k.yaml new file mode 100644 index 0000000000..2887a510d0 --- 
/dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-8k.yaml @@ -0,0 +1,40 @@ +model: ernie-4.0-turbo-8k +label: + en_US: Ernie-4.0-turbo-8K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 2048 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + use_template: response_format + - name: disable_search + label: + zh_Hans: 禁用搜索 + en_US: Disable Search + type: boolean + help: + zh_Hans: 禁用模型自行进行外部搜索。 + en_US: Disable the model to perform external search. + required: false diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml index 52e1dc832d..74451ff9e3 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k-0321.yaml @@ -28,3 +28,4 @@ parameter_rules: default: 1.0 min: 1.0 max: 2.0 +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k.yaml new file mode 100644 index 0000000000..4b11b3e895 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-character-8k.yaml @@ -0,0 +1,30 @@ +model: ernie-character-8k +label: + en_US: ERNIE-Character-8K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.95 + - name: top_p
use_template: top_p + min: 0 + max: 1.0 + default: 0.7 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 1024 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml index 78325c1d64..97ecb03f87 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0308.yaml @@ -28,3 +28,4 @@ parameter_rules: default: 1.0 min: 1.0 max: 2.0 +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0922.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0922.yaml index ebb47417cc..7410ce51df 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0922.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-8k-0922.yaml @@ -28,3 +28,4 @@ parameter_rules: default: 1.0 min: 1.0 max: 2.0 +deprecated: true diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 9aeab04cd2..bc7f29cf6e 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -97,6 +97,7 @@ class BaiduAccessToken: baidu_access_tokens_lock.release() return token + class ErnieMessage: class Role(Enum): USER = 'user' @@ -137,7 +138,9 @@ class ErnieBotModel: 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', + 'ernie-character-8k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', 'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', } @@ -149,7 +152,8 @@ class ErnieBotModel: 'ernie-3.5-8k-1222', 'ernie-3.5-4k-0205', 'ernie-3.5-128k', - 'ernie-4.0-8k' + 'ernie-4.0-8k', + 'ernie-4.0-turbo-8k', 'ernie-4.0-turbo-8k-preview' ] diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 0ef63f8e23..988bb0ce44 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): if credentials['server_url'].endswith('/'): credentials['server_url'] = credentials['server_url'][:-1] + api_key = credentials.get('api_key') or "abc" + client = OpenAI( base_url=f'{credentials["server_url"]}/v1', - api_key='abc', + api_key=api_key, max_retries=3, timeout=60, ) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 17b85862c9..649898f47a 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel): docs=[] ) - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + api_key = credentials.get('api_key') + if server_url.endswith('/'): + server_url = server_url[:-1] + auth_headers =
{'Authorization': f'Bearer {api_key}'} if api_key else {} + + try: + handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) + response = handle.rerank( + documents=docs, + query=query, + top_n=top_n, + ) + except RuntimeError as e: + raise InvokeServerUnavailableError(str(e)) - handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={}) - response = handle.rerank( - documents=docs, - query=query, - top_n=top_n, - ) rerank_documents = [] for idx, result in enumerate(response['results']): @@ -102,7 +110,7 @@ class XinferenceRerankModel(RerankModel): if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( 'please check model type, the model you want to invoke is not a rerank model') - + self.invoke( model=model, credentials=credentials, diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index f60d8d3443..9ee3621317 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel): } def _speech2text_invoke( - self, - model: str, - credentials: dict, + self, + model: str, + credentials: dict, file: IO[bytes], language: Optional[str] = None, prompt: Optional[str] = None, @@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel): :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. 
:return: text for given audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + api_key = credentials.get('api_key') + if server_url.endswith('/'): + server_url = server_url[:-1] + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} - handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={}) - response = handle.transcriptions( - audio=file, - language = language, - prompt = prompt, - response_format = response_format, - temperature = temperature - ) + try: + handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) + response = handle.transcriptions( + audio=file, + language=language, + prompt=prompt, + response_format=response_format, + temperature=temperature + ) + except RuntimeError as e: + raise InvokeServerUnavailableError(str(e)) return response["text"] diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e8429cecd4..11f1e29cb3 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ server_url = credentials['server_url'] model_uid = credentials['model_uid'] - + api_key = credentials.get('api_key') if server_url.endswith('/'): server_url = server_url[:-1] + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={}) + handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) embeddings = handle.create_embedding(input=texts) except RuntimeError as e: - raise InvokeServerUnavailableError(e) - + 
raise InvokeServerUnavailableError(str(e)) + """ for convenience, the response json is like: class Embedding(TypedDict): @@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): try: if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - + server_url = credentials['server_url'] model_uid = credentials['model_uid'] extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) @@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = server_url[:-1] client = Client(base_url=server_url) - + try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: @@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): KeyError ] } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ used to define customizable model schema """ - + entity = AIModelEntity( model=model, label=I18nObject( diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml index 28ffc0389e..9496c66fdd 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference.yaml +++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml @@ -46,3 +46,12 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的Model UID en_US: Enter the model uid + - variable: api_key + label: + zh_Hans: API密钥 + en_US: API key + type: text-input + required: false + placeholder: + zh_Hans: 在此输入您的API密钥 + en_US: Enter the api key diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index 02838cb1bd..67bc6df6fd 100644 --- 
a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -42,6 +42,7 @@ class BaseKeyword(ABC): doc_id = text.metadata['doc_id'] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: + # FIXME: Mutation to loop iterable `texts` during iteration texts.remove(text) return texts diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 623b7a3123..8814c61433 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -11,7 +11,7 @@ from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -86,7 +86,7 @@ class RetrievalService: exception_message = ';\n'.join(exceptions) raise Exception(exception_message) - if retrival_method == RetrievalMethod.HYBRID_SEARCH: + if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( query=query, @@ -142,7 +142,7 @@ class RetrievalService: ) if documents: - if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH: + if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents.extend(data_post_processor.invoke( query=query, @@ -174,7 +174,7 @@ class RetrievalService: top_k=top_k ) if documents: - if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH: + if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) 
all_documents.extend(data_post_processor.invoke( query=query, diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index d7a5dd5dcc..442d71293f 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -7,8 +7,8 @@ _import_err_msg = ( "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" ) -from flask import current_app +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel): "region_id": self.region_id, "read_timeout": self.read_timeout, } - + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance - + def __init__(self, collection_name: str, config: AnalyticdbConfig): # collection_name must be updated every time self._collection_name = collection_name.lower() @@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector): raise ValueError( f"failed to create namespace {self.config.namespace}: {e}" ) - + def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException @@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector): def get_type(self) -> str: return VectorType.ANALYTICDB - + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection_if_not_exists(dimension) @@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector): ) response = 
self._client.query_collection_data(request) return len(response.body.matches.match) > 0 - + def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models ids_str = ",".join(f"'{id}'" for id in ids) @@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector): ) documents.append(doc) return documents - + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models score_threshold = ( @@ -291,17 +291,20 @@ class AnalyticdbVector(BaseVector): ) documents.append(doc) return documents - + def delete(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - request = gpdb_20160503_models.DeleteCollectionRequest( - collection=self._collection_name, - dbinstance_id=self.config.instance_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - region_id=self.config.region_id, - ) - self._client.delete_collection(request) + try: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( + collection=self._collection_name, + dbinstance_id=self.config.instance_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + region_id=self.config.region_id, + ) + self._client.delete_collection(request) + except Exception as e: + raise e class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): @@ -316,17 +319,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) ) - config = current_app.config + + # TODO handle optional params return AnalyticdbVector( collection_name, AnalyticdbConfig( - access_key_id=config.get("ANALYTICDB_KEY_ID"), - access_key_secret=config.get("ANALYTICDB_KEY_SECRET"), - 
region_id=config.get("ANALYTICDB_REGION_ID"), - instance_id=config.get("ANALYTICDB_INSTANCE_ID"), - account=config.get("ANALYTICDB_ACCOUNT"), - account_password=config.get("ANALYTICDB_PASSWORD"), - namespace=config.get("ANALYTICDB_NAMESPACE"), - namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"), + access_key_id=dify_config.ANALYTICDB_KEY_ID, + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, + region_id=dify_config.ANALYTICDB_REGION_ID, + instance_id=dify_config.ANALYTICDB_INSTANCE_ID, + account=dify_config.ANALYTICDB_ACCOUNT, + account_password=dify_config.ANALYTICDB_PASSWORD, + namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 2d4e1975ea..3629887b44 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -3,9 +3,9 @@ from typing import Any, Optional import chromadb from chromadb import QueryResult, Settings -from flask import current_app from pydantic import BaseModel +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -111,7 +111,8 @@ class ChromaVector(BaseVector): metadata=metadata, ) docs.append(doc) - + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -133,15 +134,14 @@ class ChromaVectorFactory(AbstractVectorFactory): } dataset.index_struct = json.dumps(index_struct_dict) - config = current_app.config return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=config.get('CHROMA_HOST'), - 
port=int(config.get('CHROMA_PORT')), - tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT), - database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE), - auth_provider=config.get('CHROMA_AUTH_PROVIDER'), - auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'), + host=dify_config.CHROMA_HOST, + port=dify_config.CHROMA_PORT, + tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, + database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, + auth_provider=dify_config.CHROMA_AUTH_PROVIDER, + auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, ), ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 02b715d768..5f2ab7c5fc 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -3,10 +3,10 @@ import logging from typing import Any, Optional from uuid import uuid4 -from flask import current_app from pydantic import BaseModel, model_validator from pymilvus import MilvusClient, MilvusException, connections +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) - config = current_app.config return MilvusVector( collection_name=collection_name, config=MilvusConfig( - host=config.get('MILVUS_HOST'), - port=config.get('MILVUS_PORT'), - user=config.get('MILVUS_USER'), - password=config.get('MILVUS_PASSWORD'), - secure=config.get('MILVUS_SECURE'), - database=config.get('MILVUS_DATABASE'), + host=dify_config.MILVUS_HOST, + port=dify_config.MILVUS_PORT, + user=dify_config.MILVUS_USER, + password=dify_config.MILVUS_PASSWORD, + secure=dify_config.MILVUS_SECURE, + 
database=dify_config.MILVUS_DATABASE, ) ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 811b08818c..241b5a8414 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -5,9 +5,9 @@ from enum import Enum from typing import Any from clickhouse_connect import get_client -from flask import current_app from pydantic import BaseModel +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -156,15 +156,14 @@ class MyScaleVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) - config = current_app.config return MyScaleVector( collection_name=collection_name, config=MyScaleConfig( - host=config.get("MYSCALE_HOST", "localhost"), - port=int(config.get("MYSCALE_PORT", 8123)), - user=config.get("MYSCALE_USER", "default"), - password=config.get("MYSCALE_PASSWORD", ""), - database=config.get("MYSCALE_DATABASE", "default"), - fts_params=config.get("MYSCALE_FTS_PARAMS", ""), + host=dify_config.MYSCALE_HOST, + port=dify_config.MYSCALE_PORT, + user=dify_config.MYSCALE_USER, + password=dify_config.MYSCALE_PASSWORD, + database=dify_config.MYSCALE_DATABASE, + fts_params=dify_config.MYSCALE_FTS_PARAMS, ), ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 744ff2d517..d834e8ce14 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -4,11 +4,11 @@ import ssl from typing import Any, Optional from uuid import uuid4 -from flask import current_app from opensearchpy import OpenSearch, helpers from 
opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - config = current_app.config open_search_config = OpenSearchConfig( - host=config.get('OPENSEARCH_HOST'), - port=config.get('OPENSEARCH_PORT'), - user=config.get('OPENSEARCH_USER'), - password=config.get('OPENSEARCH_PASSWORD'), - secure=config.get('OPENSEARCH_SECURE'), + host=dify_config.OPENSEARCH_HOST, + port=dify_config.OPENSEARCH_PORT, + user=dify_config.OPENSEARCH_USER, + password=dify_config.OPENSEARCH_PASSWORD, + secure=dify_config.OPENSEARCH_SECURE, ) return OpenSearchVector( diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 5f7723508c..f75310205c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -6,9 +6,9 @@ from typing import Any import numpy import oracledb -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel): SQL_CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS {table_name} ( - id varchar2(100) + id varchar2(100) ,text CLOB NOT NULL ,meta JSON ,embedding vector NOT NULL -) +) """ @@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) - 
config = current_app.config return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=config.get("ORACLE_HOST"), - port=config.get("ORACLE_PORT"), - user=config.get("ORACLE_USER"), - password=config.get("ORACLE_PASSWORD"), - database=config.get("ORACLE_DATABASE"), + host=dify_config.ORACLE_HOST, + port=dify_config.ORACLE_PORT, + user=dify_config.ORACLE_USER, + password=dify_config.ORACLE_PASSWORD, + database=dify_config.ORACLE_DATABASE, ), ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 63c8edfbc3..82bdc5d4b9 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -3,7 +3,6 @@ import logging from typing import Any from uuid import UUID, uuid4 -from flask import current_app from numpy import ndarray from pgvecto_rs.sqlalchemy import Vector from pydantic import BaseModel, model_validator @@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector @@ -93,7 +93,7 @@ class PGVectoRS(BaseVector): text TEXT NOT NULL, meta JSONB NOT NULL, vector vector({dimension}) NOT NULL - ) using heap; + ) using heap; """) session.execute(create_statement) index_statement = sql_text(f""" @@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) - config = current_app.config + return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=config.get('PGVECTO_RS_HOST'), - port=config.get('PGVECTO_RS_PORT'), - 
user=config.get('PGVECTO_RS_USER'), - password=config.get('PGVECTO_RS_PASSWORD'), - database=config.get('PGVECTO_RS_DATABASE'), + host=dify_config.PGVECTO_RS_HOST, + port=dify_config.PGVECTO_RS_PORT, + user=dify_config.PGVECTO_RS_USER, + password=dify_config.PGVECTO_RS_PASSWORD, + database=dify_config.PGVECTO_RS_DATABASE, ), dim=dim - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 72d0a85f8d..33ca5bc028 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -5,9 +5,9 @@ from typing import Any import psycopg2.extras import psycopg2.pool -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( text TEXT NOT NULL, meta JSONB NOT NULL, embedding vector({dimension}) NOT NULL -) using heap; +) using heap; """ @@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) - config = current_app.config return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=config.get("PGVECTOR_HOST"), - port=config.get("PGVECTOR_PORT"), - user=config.get("PGVECTOR_USER"), - password=config.get("PGVECTOR_PASSWORD"), - database=config.get("PGVECTOR_DATABASE"), + host=dify_config.PGVECTOR_HOST, + port=dify_config.PGVECTOR_PORT, + user=dify_config.PGVECTOR_USER, + password=dify_config.PGVECTOR_PASSWORD, + database=dify_config.PGVECTOR_DATABASE, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 
b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index bccc3a39f6..f9a9389868 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -19,6 +19,7 @@ from qdrant_client.http.models import ( ) from qdrant_client.local.qdrant_local import QdrantLocal +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -361,6 +362,8 @@ class QdrantVector(BaseVector): metadata=metadata, ) docs.append(doc) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -444,11 +447,11 @@ class QdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), + endpoint=dify_config.QDRANT_URL, + api_key=dify_config.QDRANT_API_KEY, root_path=config.root_path, - timeout=config.get('QDRANT_CLIENT_TIMEOUT'), - grpc_port=config.get('QDRANT_GRPC_PORT'), - prefer_grpc=config.get('QDRANT_GRPC_ENABLED') + timeout=dify_config.QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.QDRANT_GRPC_PORT, + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED ) ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 4fe1df717a..2e0bd6f303 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -2,7 +2,6 @@ import json import uuid from typing import Any, Optional -from flask import current_app from pydantic import BaseModel, model_validator from sqlalchemy import Column, Sequence, String, Table, create_engine, insert from sqlalchemy import text as sql_text @@ -19,6 +18,7 @@ try: except 
ImportError: from sqlalchemy.ext.declarative import declarative_base +from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document from extensions.ext_redis import redis_client @@ -85,7 +85,7 @@ class RelytVector(BaseVector): document TEXT NOT NULL, metadata JSON NOT NULL, embedding vector({dimension}) NOT NULL - ) using heap; + ) using heap; """) session.execute(create_statement) index_statement = sql_text(f""" @@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.RELYT, collection_name)) - config = current_app.config return RelytVector( collection_name=collection_name, config=RelytConfig( - host=config.get('RELYT_HOST'), - port=config.get('RELYT_PORT'), - user=config.get('RELYT_USER'), - password=config.get('RELYT_PASSWORD'), - database=config.get('RELYT_DATABASE'), + host=dify_config.RELYT_HOST, + port=dify_config.RELYT_PORT, + user=dify_config.RELYT_USER, + password=dify_config.RELYT_PASSWORD, + database=dify_config.RELYT_DATABASE, ), group_id=dataset.id ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3af85854d2..3325a1028e 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,13 +1,13 @@ import json from typing import Any, Optional -from flask import current_app from pydantic import BaseModel from tcvectordb import VectorDBClient from tcvectordb.model import document, enum from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -198,8 +198,6 @@ class TencentVector(BaseVector): 
self._db.drop_collection(name=self._collection_name) - - class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: @@ -212,16 +210,15 @@ class TencentVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) - config = current_app.config return TencentVector( collection_name=collection_name, config=TencentConfig( - url=config.get('TENCENT_VECTOR_DB_URL'), - api_key=config.get('TENCENT_VECTOR_DB_API_KEY'), - timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'), - username=config.get('TENCENT_VECTOR_DB_USERNAME'), - database=config.get('TENCENT_VECTOR_DB_DATABASE'), - shard=config.get('TENCENT_VECTOR_DB_SHARD'), - replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'), + url=dify_config.TENCENT_VECTOR_DB_URL, + api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, + timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, + username=dify_config.TENCENT_VECTOR_DB_USERNAME, + database=dify_config.TENCENT_VECTOR_DB_DATABASE, + shard=dify_config.TENCENT_VECTOR_DB_SHARD, + replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, ) - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 5922db1176..d3685c0991 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -3,12 +3,12 @@ import logging from typing import Any import sqlalchemy -from flask import current_app from pydantic import BaseModel, model_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from 
core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -198,8 +198,8 @@ class TiDBVector(BaseVector): with Session(self._engine) as session: select_statement = sql_text( f"""SELECT meta, text, distance FROM ( - SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance - FROM {self._collection_name} + SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance + FROM {self._collection_name} ORDER BY distance LIMIT {top_k} ) t WHERE distance < {distance};""" @@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) - config = current_app.config return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=config.get('TIDB_VECTOR_HOST'), - port=config.get('TIDB_VECTOR_PORT'), - user=config.get('TIDB_VECTOR_USER'), - password=config.get('TIDB_VECTOR_PASSWORD'), - database=config.get('TIDB_VECTOR_DATABASE'), - program_name=config.get('APPLICATION_NAME'), + host=dify_config.TIDB_VECTOR_HOST, + port=dify_config.TIDB_VECTOR_PORT, + user=dify_config.TIDB_VECTOR_USER, + password=dify_config.TIDB_VECTOR_PASSWORD, + database=dify_config.TIDB_VECTOR_DATABASE, + program_name=dify_config.APPLICATION_NAME, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index dbd8b6284b..0b1d58856c 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -61,9 +61,14 @@ class BaseVector(ABC): doc_id = text.metadata['doc_id'] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: + # FIXME: Mutation to loop iterable `texts` during iteration texts.remove(text) return texts def _get_uuids(self, texts: list[Document]) -> list[str]: return [text.metadata['doc_id'] for text in texts] + + @property + def collection_name(self): + return self._collection_name diff 
--git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index f8b58e1b9a..509273e8ea 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from flask import current_app - +from configs import dify_config from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -10,6 +9,7 @@ from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document +from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -37,8 +37,7 @@ class Vector: self._vector_processor = self._init_vector() def _init_vector(self) -> BaseVector: - config = current_app.config - vector_type = config.get('VECTOR_STORE') + vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: vector_type = self._dataset.index_struct_dict['type'] @@ -136,6 +135,10 @@ class Vector: def delete(self) -> None: self._vector_processor.delete() + # delete collection redis cache + if self._vector_processor.collection_name: + collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: model_manager = ModelManager() @@ -154,6 +157,7 @@ class Vector: doc_id = text.metadata['doc_id'] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: + # FIXME: Mutation to loop iterable `texts` during iteration texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index b7c5c96a7d..87fc5ff158 100644 --- 
a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -4,9 +4,9 @@ from typing import Any, Optional import requests import weaviate -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -216,7 +216,8 @@ class WeaviateVector(BaseVector): if score > score_threshold: doc.metadata['score'] = score docs.append(doc) - + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -281,9 +282,9 @@ class WeaviateVectorFactory(AbstractVectorFactory): return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=current_app.config.get('WEAVIATE_ENDPOINT'), - api_key=current_app.config.get('WEAVIATE_API_KEY'), - batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE')) + endpoint=dify_config.WEAVIATE_ENDPOINT, + api_key=dify_config.WEAVIATE_API_KEY, + batch_size=dify_config.WEAVIATE_BATCH_SIZE ), attributes=attributes ) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 909bfdc137..d01cf48fac 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -5,8 +5,8 @@ from typing import Union from urllib.parse import unquote import requests -from flask import current_app +from configs import dify_config from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -94,9 +94,9 @@ class ExtractProcessor: storage.download(upload_file.key, file_path) input_file = 
Path(file_path) file_extension = input_file.suffix.lower() - etl_type = current_app.config['ETL_TYPE'] - unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] - unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY'] + etl_type = dify_config.ETL_TYPE + unstructured_api_url = dify_config.UNSTRUCTURED_API_URL + unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY if etl_type == 'Unstructured': if file_extension == '.xlsx' or file_extension == '.xls': extractor = ExcelExtractor(file_path) diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7c6101010e..9535455909 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -3,8 +3,8 @@ import logging from typing import Any, Optional import requests -from flask import current_app +from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor): self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: - integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') + integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( "Must specify `integration_token` or set environment " diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 9045966da9..ac4a56319b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -8,8 +8,8 @@ from urllib.parse import urlparse import requests from docx import Document as DocxDocument -from flask import current_app +from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db 
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor): storage.save(file_key, rel.target_part.blob) # save file to db - config = current_app.config upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=config['STORAGE_TYPE'], + storage_type=dify_config.STORAGE_TYPE, key=file_key, name=file_key, size=0, @@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" return image_map diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index edc16c821a..33e78ce8c5 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from flask import current_app - +from configs import dify_config from core.model_manager import ModelInstance from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.models.document import Document @@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC): # The user-defined segmentation rule rules = processing_rule['rules'] segmentation = rules["segmentation"] - max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH']) + max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ea2a194a68..c1f5e0820c 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ 
b/api/core/rag/retrieval/dataset_retrieval.py @@ -28,7 +28,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -464,7 +464,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/rag/retrieval/retrival_methods.py b/api/core/rag/retrieval/retrival_methods.py index 9b7907013d..12aa28a51c 100644 --- a/api/core/rag/retrieval/retrival_methods.py +++ b/api/core/rag/retrieval/retrival_methods.py @@ -1,15 +1,15 @@ from enum import Enum -class RetrievalMethod(str, Enum): +class RetrievalMethod(Enum): SEMANTIC_SEARCH = 'semantic_search' FULL_TEXT_SEARCH = 'full_text_search' HYBRID_SEARCH = 'hybrid_search' @staticmethod def is_support_semantic_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH} + return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} @staticmethod def is_support_fulltext_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH} + return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index fa13629ef7..3a3ff64426 100644 --- a/api/core/tools/provider/_position.yaml +++ 
b/api/core/tools/provider/_position.yaml @@ -30,3 +30,4 @@ - feishu - feishu_base - slack +- tianditu diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index 61609947fa..f985deade5 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -1,5 +1,5 @@ +import base64 import random -from base64 import b64decode from typing import Any, Union from openai import OpenAI @@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool): result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) + mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) + blob_message = self.create_blob_message(blob=blob_image, + meta={'mime_type': mime_type}, + save_as=self.VARIABLE_KEY.IMAGE.value) + result.append(blob_message) return result + @staticmethod + def _decode_image(base64_image: str) -> tuple[str, bytes]: + """ + Decode a base64 encoded image. If the image is not prefixed with a MIME type, + it assumes 'image/png' as the default. + + :param base64_image: Base64 encoded image string + :return: A tuple containing the MIME type and the decoded image bytes + """ + if DallE3Tool._is_plain_base64(base64_image): + return 'image/png', base64.b64decode(base64_image) + else: + return DallE3Tool._extract_mime_and_data(base64_image) + + @staticmethod + def _is_plain_base64(encoded_str: str) -> bool: + """ + Check if the given encoded string is plain base64 without a MIME type prefix. 
+ + :param encoded_str: Base64 encoded image string + :return: True if the string is plain base64, False otherwise + """ + return not encoded_str.startswith('data:image') + + @staticmethod + def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: + """ + Extract MIME type and image data from a base64 encoded string with a MIME type prefix. + + :param encoded_str: Base64 encoded image string with MIME type prefix + :return: A tuple containing the MIME type and the decoded image bytes + """ + mime_type = encoded_str.split(';')[0].split(':')[1] + image_data_base64 = encoded_str.split(',')[1] + decoded_data = base64.b64decode(image_data_base64) + return mime_type, decoded_data + @staticmethod def _generate_random_id(length=8): characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' diff --git a/api/core/tools/provider/builtin/getimgai/_assets/icon.svg b/api/core/tools/provider/builtin/getimgai/_assets/icon.svg new file mode 100644 index 0000000000..6b2513386d --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py new file mode 100644 index 0000000000..c81d5fa333 --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -0,0 +1,22 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.getimgai.tools.text2image import Text2ImageTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GetImgAIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + # Example validation using the text2image tool + Text2ImageTool().fork_tool_runtime( + runtime={"credentials": credentials} + ).invoke( + user_id='', + tool_parameters={ + "prompt": "A fire egg", + "response_format": "url", + "style": 
"photorealism", + } + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.yaml b/api/core/tools/provider/builtin/getimgai/getimgai.yaml new file mode 100644 index 0000000000..c9db0a9e22 --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/getimgai.yaml @@ -0,0 +1,29 @@ +identity: + author: Matri Qi + name: getimgai + label: + en_US: getimg.ai + zh_CN: getimg.ai + description: + en_US: GetImg API integration for image generation and scraping. + icon: icon.svg + tags: + - image +credentials_for_provider: + getimg_api_key: + type: secret-input + required: true + label: + en_US: getimg.ai API Key + placeholder: + en_US: Please input your getimg.ai API key + help: + en_US: Get your getimg.ai API key from your getimg.ai account settings. If you are using a self-hosted version, you may enter any key at your convenience. + url: https://dashboard.getimg.ai/api-keys + base_url: + type: text-input + required: false + label: + en_US: getimg.ai server's Base URL + placeholder: + en_US: https://api.getimg.ai/v1 diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py new file mode 100644 index 0000000000..e28c57649c --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -0,0 +1,59 @@ +import logging +import time +from collections.abc import Mapping +from typing import Any + +import requests +from requests.exceptions import HTTPError + +logger = logging.getLogger(__name__) + +class GetImgAIApp: + def __init__(self, api_key: str | None = None, base_url: str | None = None): + self.api_key = api_key + self.base_url = base_url or 'https://api.getimg.ai/v1' + if not self.api_key: + raise ValueError("API key is required") + + def _prepare_headers(self): + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + return 
headers + + def _request( + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, + ) -> Mapping[str, Any] | None: + for i in range(retries): + try: + response = requests.request(method, url, json=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: + time.sleep(backoff_factor * (2 ** i)) + else: + raise + return None + + def text2image( + self, mode: str, **kwargs + ): + data = kwargs['params'] + if not data.get('prompt'): + raise ValueError("Prompt is required") + + endpoint = f'{self.base_url}/{mode}/text-to-image' + headers = self._prepare_headers() + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request('POST', endpoint, data, headers) + if response is None: + raise HTTPError("Failed to initiate getimg.ai after multiple retries") + return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py new file mode 100644 index 0000000000..dad7314479 --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -0,0 +1,39 @@ +import json +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.getimgai.getimgai_appx import GetImgAIApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class Text2ImageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + + options = { + 'style': tool_parameters.get('style'), + 'prompt': tool_parameters.get('prompt'), + 'aspect_ratio': 
tool_parameters.get('aspect_ratio'), + 'output_format': tool_parameters.get('output_format', 'jpeg'), + 'response_format': tool_parameters.get('response_format', 'url'), + 'width': tool_parameters.get('width'), + 'height': tool_parameters.get('height'), + 'steps': tool_parameters.get('steps'), + 'negative_prompt': tool_parameters.get('negative_prompt'), + 'prompt_2': tool_parameters.get('prompt_2'), + } + options = {k: v for k, v in options.items() if v} + + text2image_result = app.text2image( + mode=tool_parameters.get('mode', 'essential-v2'), + params=options, + wait=True + ) + + if not isinstance(text2image_result, str): + text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) + + if not text2image_result: + return self.create_text_message("getimg.ai request failed.") + + return self.create_text_message(text2image_result) diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.yaml b/api/core/tools/provider/builtin/getimgai/tools/text2image.yaml new file mode 100644 index 0000000000..d972186f56 --- /dev/null +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.yaml @@ -0,0 +1,167 @@ +identity: + name: text2image + author: Matri Qi + label: + en_US: text2image + icon: icon.svg +description: + human: + en_US: Generate image via getimg.ai. + llm: This tool is used to generate image from prompt or image via https://getimg.ai. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + human_description: + en_US: The text prompt used to generate the image. The getimg.ai service will generate an image based on this prompt. + llm_description: this prompt text will be used to generate image. + form: llm + - name: mode + type: select + required: false + label: + en_US: mode + human_description: + en_US: The getimg.ai mode to use. The mode determines the endpoint used to generate the image. 
+ form: form + options: + - value: "essential-v2" + label: + en_US: essential-v2 + - value: stable-diffusion-xl + label: + en_US: stable-diffusion-xl + - value: stable-diffusion + label: + en_US: stable-diffusion + - value: latent-consistency + label: + en_US: latent-consistency + - name: style + type: select + required: false + label: + en_US: style + human_description: + en_US: The style preset to use. The style preset guides the generation towards a particular style. It's just efficient for `Essential V2` mode. + form: form + options: + - value: photorealism + label: + en_US: photorealism + - value: anime + label: + en_US: anime + - value: art + label: + en_US: art + - name: aspect_ratio + type: select + required: false + label: + en_US: "aspect ratio" + human_description: + en_US: The aspect ratio of the generated image. It's just efficient for `Essential V2` mode. + form: form + options: + - value: "1:1" + label: + en_US: "1:1" + - value: "4:5" + label: + en_US: "4:5" + - value: "5:4" + label: + en_US: "5:4" + - value: "2:3" + label: + en_US: "2:3" + - value: "3:2" + label: + en_US: "3:2" + - value: "4:7" + label: + en_US: "4:7" + - value: "7:4" + label: + en_US: "7:4" + - name: output_format + type: select + required: false + label: + en_US: "output format" + human_description: + en_US: The file format of the generated image. + form: form + options: + - value: jpeg + label: + en_US: jpeg + - value: png + label: + en_US: png + - name: response_format + type: select + required: false + label: + en_US: "response format" + human_description: + en_US: The format in which the generated images are returned. Must be one of url or b64. URLs are only valid for 1 hour after the image has been generated. + form: form + options: + - value: url + label: + en_US: url + - value: b64 + label: + en_US: b64 + - name: model + type: string + required: false + label: + en_US: model + human_description: + en_US: Model ID supported by this pipeline and family. 
It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode. + form: form + - name: negative_prompt + type: string + required: false + label: + en_US: negative prompt + human_description: + en_US: Text input that will not guide the image generation. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode. + form: form + - name: prompt_2 + type: string + required: false + label: + en_US: prompt2 + human_description: + en_US: Prompt sent to second tokenizer and text encoder. If not defined, prompt is used in both text-encoders. It's just efficient for `Stable Diffusion XL` mode. + form: form + - name: width + type: number + required: false + label: + en_US: width + human_description: + en_US: The width of the generated image in pixels. Width needs to be multiple of 64. + form: form + - name: height + type: number + required: false + label: + en_US: height + human_description: + en_US: The height of the generated image in pixels. Height needs to be multiple of 64. + form: form + - name: steps + type: number + required: false + label: + en_US: steps + human_description: + en_US: The number of denoising steps. More steps usually can produce higher quality images, but take more time to generate. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode. 
+ form: form diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index b09e494881..1b49cfe2f3 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -19,28 +19,29 @@ class JSONDeleteTool(BuiltinTool): content = tool_parameters.get('content', '') if not content: return self.create_text_message('Invalid parameter content') - + # Get query query = tool_parameters.get('query', '') if not query: return self.create_text_message('Invalid parameter query') - + + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: - result = self._delete(content, query) + result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: return self.create_text_message(f'Failed to delete JSON content: {str(e)}') - def _delete(self, origin_json: str, query: str) -> str: + def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) expr = parse('$.' 
+ query.lstrip('$.')) # Ensure query path starts with $ - + matches = expr.find(input_data) - + if not matches: - return json.dumps(input_data, ensure_ascii=True) # No changes if no matches found - + return json.dumps(input_data, ensure_ascii=ensure_ascii) # No changes if no matches found + for match in matches: if isinstance(match.context.value, dict): # Delete key from dictionary @@ -53,7 +54,7 @@ class JSONDeleteTool(BuiltinTool): parent = match.context.parent if parent: del parent.value[match.path.fields[-1]] - - return json.dumps(input_data, ensure_ascii=True) + + return json.dumps(input_data, ensure_ascii=ensure_ascii) except Exception as e: - raise Exception(f"Delete operation failed: {str(e)}") \ No newline at end of file + raise Exception(f"Delete operation failed: {str(e)}") diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.yaml b/api/core/tools/provider/builtin/json_process/tools/delete.yaml index 4cfa90b861..4d390e40d1 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.yaml +++ b/api/core/tools/provider/builtin/json_process/tools/delete.yaml @@ -38,3 +38,15 @@ parameters: pt_BR: JSONPath query to locate the element to delete llm_description: JSONPath query to locate the element to delete form: llm + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index aa5986e2b4..27e34f1ff3 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -19,31 +19,31 @@ class JSONParseTool(BuiltinTool): content = tool_parameters.get('content', '') if not content: return 
self.create_text_message('Invalid parameter content') - + # get query query = tool_parameters.get('query', '') if not query: return self.create_text_message('Invalid parameter query') - + # get new value new_value = tool_parameters.get('new_value', '') if not new_value: return self.create_text_message('Invalid parameter new_value') - + # get insert position index = tool_parameters.get('index') - + # get create path create_path = tool_parameters.get('create_path', False) - + + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: - result = self._insert(content, query, new_value, index, create_path) + result = self._insert(content, query, new_value, ensure_ascii, index, create_path) return self.create_text_message(str(result)) except Exception: return self.create_text_message('Failed to insert JSON content') - - def _insert(self, origin_json, query, new_value, index=None, create_path=False): + def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False): try: input_data = json.loads(origin_json) expr = parse(query) @@ -51,9 +51,9 @@ class JSONParseTool(BuiltinTool): new_value = json.loads(new_value) except json.JSONDecodeError: new_value = new_value - + matches = expr.find(input_data) - + if not matches and create_path: # create new path path_parts = query.strip('$').strip('.').split('.') @@ -91,7 +91,7 @@ class JSONParseTool(BuiltinTool): else: # replace old value with new value match.full_path.update(input_data, new_value) - - return json.dumps(input_data, ensure_ascii=True) + + return json.dumps(input_data, ensure_ascii=ensure_ascii) except Exception as e: - return str(e) \ No newline at end of file + return str(e) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.yaml b/api/core/tools/provider/builtin/json_process/tools/insert.yaml index 66a6ff9929..63e7816455 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.yaml +++ 
b/api/core/tools/provider/builtin/json_process/tools/insert.yaml @@ -75,3 +75,15 @@ parameters: zh_Hans: 否 pt_BR: "No" form: form + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index b246afc07e..ecd39113ae 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -19,33 +19,34 @@ class JSONParseTool(BuiltinTool): content = tool_parameters.get('content', '') if not content: return self.create_text_message('Invalid parameter content') - + # get json filter json_filter = tool_parameters.get('json_filter', '') if not json_filter: return self.create_text_message('Invalid parameter json_filter') + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: - result = self._extract(content, json_filter) + result = self._extract(content, json_filter, ensure_ascii) return self.create_text_message(str(result)) except Exception: return self.create_text_message('Failed to extract JSON content') # Extract data from JSON content - def _extract(self, content: str, json_filter: str) -> str: + def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str: try: input_data = json.loads(content) expr = parse(json_filter) result = [match.value for match in expr.find(input_data)] - + if len(result) == 1: result = result[0] - + if isinstance(result, dict | list): - return json.dumps(result, ensure_ascii=True) + return json.dumps(result, ensure_ascii=ensure_ascii) elif isinstance(result, str | int | float | bool) or result is None: return str(result) else: return repr(result) except Exception as e: - return str(e) 
\ No newline at end of file + return str(e) diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.yaml b/api/core/tools/provider/builtin/json_process/tools/parse.yaml index b619dcde94..c35f4eac07 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.yaml +++ b/api/core/tools/provider/builtin/json_process/tools/parse.yaml @@ -38,3 +38,15 @@ parameters: pt_BR: JSON fields to be parsed llm_description: JSON fields to be parsed form: llm + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index 9f127b9d06..be696bce0e 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -19,61 +19,62 @@ class JSONReplaceTool(BuiltinTool): content = tool_parameters.get('content', '') if not content: return self.create_text_message('Invalid parameter content') - + # get query query = tool_parameters.get('query', '') if not query: return self.create_text_message('Invalid parameter query') - + # get replace value replace_value = tool_parameters.get('replace_value', '') if not replace_value: return self.create_text_message('Invalid parameter replace_value') - + # get replace model replace_model = tool_parameters.get('replace_model', '') if not replace_model: return self.create_text_message('Invalid parameter replace_model') + ensure_ascii = tool_parameters.get('ensure_ascii', True) try: if replace_model == 'pattern': # get replace pattern replace_pattern = tool_parameters.get('replace_pattern', '') if not replace_pattern: return self.create_text_message('Invalid parameter replace_pattern') - 
result = self._replace_pattern(content, query, replace_pattern, replace_value) + result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii) elif replace_model == 'key': - result = self._replace_key(content, query, replace_value) + result = self._replace_key(content, query, replace_value, ensure_ascii) elif replace_model == 'value': - result = self._replace_value(content, query, replace_value) + result = self._replace_value(content, query, replace_value, ensure_ascii) return self.create_text_message(str(result)) except Exception: return self.create_text_message('Failed to replace JSON content') # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str) -> str: + def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str: try: input_data = json.loads(content) expr = parse(query) - + matches = expr.find(input_data) - + for match in matches: new_value = match.value.replace(replace_pattern, replace_value) match.full_path.update(input_data, new_value) - - return json.dumps(input_data, ensure_ascii=True) + + return json.dumps(input_data, ensure_ascii=ensure_ascii) except Exception as e: return str(e) - + # Replace key - def _replace_key(self, content: str, query: str, replace_value: str) -> str: + def _replace_key(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str: try: input_data = json.loads(content) expr = parse(query) - + matches = expr.find(input_data) - + for match in matches: parent = match.context.value if isinstance(parent, dict): @@ -86,21 +87,21 @@ class JSONReplaceTool(BuiltinTool): if isinstance(item, dict) and old_key in item: value = item.pop(old_key) item[replace_value] = value - return json.dumps(input_data, ensure_ascii=True) + return json.dumps(input_data, ensure_ascii=ensure_ascii) except Exception as e: return str(e) - + # Replace value - def _replace_value(self, content: 
str, query: str, replace_value: str) -> str: + def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str: try: input_data = json.loads(content) expr = parse(query) - + matches = expr.find(input_data) - + for match in matches: match.full_path.update(input_data, replace_value) - - return json.dumps(input_data, ensure_ascii=True) + + return json.dumps(input_data, ensure_ascii=ensure_ascii) except Exception as e: - return str(e) \ No newline at end of file + return str(e) diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.yaml b/api/core/tools/provider/builtin/json_process/tools/replace.yaml index 556be5e8b2..cf4b1dc63f 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.yaml +++ b/api/core/tools/provider/builtin/json_process/tools/replace.yaml @@ -93,3 +93,15 @@ parameters: zh_Hans: 字符串替换 pt_BR: replace string form: form + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/spider/_assets/icon.svg b/api/core/tools/provider/builtin/spider/_assets/icon.svg new file mode 100644 index 0000000000..604a09d01d --- /dev/null +++ b/api/core/tools/provider/builtin/spider/_assets/icon.svg @@ -0,0 +1 @@ +Spider v1 Logo diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py new file mode 100644 index 0000000000..6fa431b6bb --- /dev/null +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -0,0 +1,14 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.spider.spiderApp import Spider +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class 
class RequestParamsDict(TypedDict, total=False):
    """Optional request parameters accepted by the Spider API endpoints.

    All keys are optional (``total=False``); unset keys are simply omitted
    from the request payload.
    """
    url: Optional[str]
    request: Optional[Literal["http", "chrome", "smart"]]
    limit: Optional[int]
    return_format: Optional[Literal["raw", "markdown", "html2text", "text", "bytes"]]
    tld: Optional[bool]
    depth: Optional[int]
    cache: Optional[bool]
    budget: Optional[dict[str, int]]
    locale: Optional[str]
    cookies: Optional[str]
    stealth: Optional[bool]
    headers: Optional[dict[str, str]]
    anti_bot: Optional[bool]
    metadata: Optional[bool]
    viewport: Optional[dict[str, int]]
    encoding: Optional[str]
    subdomains: Optional[bool]
    user_agent: Optional[str]
    store_data: Optional[bool]
    gpt_config: Optional[list[str]]
    fingerprint: Optional[bool]
    storageless: Optional[bool]
    readability: Optional[bool]
    proxy_enabled: Optional[bool]
    respect_robots: Optional[bool]
    query_selector: Optional[str]
    full_resources: Optional[bool]
    request_timeout: Optional[int]
    run_in_background: Optional[bool]
    skip_config_checks: Optional[bool]


class Spider:
    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Spider with an API key.

        :param api_key: A string of the API key for Spider. Defaults to the SPIDER_API_KEY environment variable.
        :raises ValueError: If no API key is provided.
        """
        self.api_key = api_key or os.getenv("SPIDER_API_KEY")
        if self.api_key is None:
            raise ValueError("No API key provided")

    def api_post(
        self,
        endpoint: str,
        data: dict,
        stream: bool,
        content_type: str = "application/json",
    ):
        """
        Send a POST request to the specified API endpoint.

        :param endpoint: The API endpoint to which the POST request is sent.
        :param data: The data (dictionary) to be sent in the POST request.
        :param stream: Boolean indicating if the response should be streamed.
        :return: The JSON response or the raw response stream if stream is True.
        :raises Exception: On any non-200 status (via _handle_error).
        """
        headers = self._prepare_headers(content_type)
        response = self._post_request(
            f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream
        )

        if stream:
            return response
        elif response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"post to {endpoint}")

    def api_get(
        self, endpoint: str, stream: bool, content_type: str = "application/json"
    ):
        """
        Send a GET request to the specified endpoint.

        :param endpoint: The API endpoint from which to retrieve data.
        :return: The JSON decoded response.
        :raises Exception: On any non-200 status (via _handle_error).
        """
        headers = self._prepare_headers(content_type)
        response = self._get_request(
            f"https://api.spider.cloud/v1/{endpoint}", headers, stream
        )
        if response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"get from {endpoint}")

    def get_credits(self):
        """
        Retrieve the account's remaining credits.

        :return: JSON response containing the number of credits left.
        """
        return self.api_get("credits", stream=False)

    def scrape_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Scrape data from the specified URL (a crawl limited to one page).

        :param url: The URL from which to scrape data.
        :param params: Optional dictionary of additional parameters for the scrape request.
        :return: JSON response containing the scraping results.
        """
        # Copy into a fresh dict so the caller's dict is never mutated.
        # This also fixes the default call path: the original tested
        # `"return_format" not in params` while `params` defaults to None,
        # which raised TypeError.
        params = dict(params or {})
        params.setdefault("return_format", "markdown")

        # A scrape is a crawl capped at a single page.
        params["limit"] = 1

        return self.api_post("crawl", {"url": url, **params}, stream, content_type)

    def crawl_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Start crawling at the specified URL.

        :param url: The URL to begin crawling.
        :param params: Optional dictionary with additional parameters to customize the crawl.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :return: JSON response or the raw response stream if streaming enabled.
        """
        # Same None-safe copy as scrape_url (original raised TypeError when
        # params was omitted).
        params = dict(params or {})
        params.setdefault("return_format", "markdown")

        return self.api_post("crawl", {"url": url, **params}, stream, content_type)

    def links(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Retrieve links from the specified URL.

        :param url: The URL from which to extract links.
        :param params: Optional parameters for the link retrieval request.
        :return: JSON response containing the links.
        """
        return self.api_post(
            "links", {"url": url, **(params or {})}, stream, content_type
        )

    def extract_contacts(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Extract contact information from the specified URL.

        :param url: The URL from which to extract contact information.
        :param params: Optional parameters for the contact extraction.
        :return: JSON response containing extracted contact details.
        """
        return self.api_post(
            "pipeline/extract-contacts",
            {"url": url, **(params or {})},
            stream,
            content_type,
        )

    def label(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Apply labeling to data extracted from the specified URL.

        :param url: The URL to label data from.
        :param params: Optional parameters to guide the labeling process.
        :return: JSON response with labeled data.
        """
        return self.api_post(
            "pipeline/label", {"url": url, **(params or {})}, stream, content_type
        )

    def _prepare_headers(self, content_type: str = "application/json"):
        """Build the standard request headers, including the bearer token."""
        return {
            "Content-Type": content_type,
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "Spider-Client/0.0.27",
        }

    def _post_request(self, url: str, data, headers, stream=False):
        return requests.post(url, headers=headers, json=data, stream=stream)

    def _get_request(self, url: str, headers, stream=False):
        return requests.get(url, headers=headers, stream=stream)

    def _delete_request(self, url: str, headers, stream=False):
        return requests.delete(url, headers=headers, stream=stream)

    def _handle_error(self, response, action):
        """Raise an Exception describing a failed request.

        402/409/500 responses carry a JSON body with an 'error' field; other
        statuses are reported with the code only.
        """
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(
                f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}"
            )
        else:
            raise Exception(
                f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}"
            )
class ScrapeTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Scrape a single page or crawl a whole site via the Spider API.

        :param user_id: id of the invoking user (unused here).
        :param tool_parameters: expects 'url' and 'mode' ('scrape' | 'crawl'),
            plus optional 'limit', 'depth', comma-separated 'blacklist' /
            'whitelist' patterns and a 'readability' flag.
        :return: a text message listing "URL:"/"CONTENT:" pairs for each page,
            or an error text message on failure.
        """
        # initialize the app object with the api key
        app = Spider(api_key=self.runtime.credentials['spider_api_key'])

        url = tool_parameters['url']
        mode = tool_parameters['mode']

        options = {
            'limit': tool_parameters.get('limit', 0),
            'depth': tool_parameters.get('depth', 0),
            'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [],
            'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [],
            'readability': tool_parameters.get('readability', False),
        }

        try:
            # The two modes only differ in which Spider endpoint wrapper is
            # called; the result formatting is identical (was duplicated).
            if mode == 'scrape':
                pages = app.scrape_url(url=url, params=options)
            elif mode == 'crawl':
                pages = app.crawl_url(url=url, params=options)
            else:
                return self.create_text_message(f"Invalid mode: {mode}")

            result = ""
            for page in pages:
                result += "URL: " + page.get('url', '') + "\n"
                result += "CONTENT: " + page.get('content', '') + "\n\n"
        except Exception as e:
            # create_text_message takes a single text argument; the original
            # passed ("An error occured", str(e)) which raised TypeError and
            # hid the real error (and misspelled "occurred").
            return self.create_text_message(f"An error occurred: {e}")

        return self.create_text_message(result)
b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.yaml new file mode 100644 index 0000000000..5b20c2fc2f --- /dev/null +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.yaml @@ -0,0 +1,102 @@ +identity: + name: scraper_crawler + author: William Espegren + label: + en_US: Web Scraper & Crawler + zh_Hans: 网页抓取与爬虫 +description: + human: + en_US: A tool for scraping & crawling webpages. Input should be a url. + zh_Hans: 用于抓取和爬取网页的工具。输入应该是一个网址。 + llm: A tool for scraping & crawling webpages. Input should be a url. +parameters: + - name: url + type: string + required: true + label: + en_US: URL + zh_Hans: 网址 + human_description: + en_US: url to be scraped or crawled + zh_Hans: 要抓取或爬取的网址 + llm_description: url to either be scraped or crawled + form: llm + - name: mode + type: select + required: true + options: + - value: scrape + label: + en_US: scrape + zh_Hans: 抓取 + - value: crawl + label: + en_US: crawl + zh_Hans: 爬取 + default: crawl + label: + en_US: Mode + zh_Hans: 模式 + human_description: + en_US: used for selecting to either scrape the website or crawl the entire website following subpages + zh_Hans: 用于选择抓取网站或爬取整个网站及其子页面 + form: form + - name: limit + type: number + required: false + label: + en_US: maximum number of pages to crawl + zh_Hans: 最大爬取页面数 + human_description: + en_US: specify the maximum number of pages to crawl per website. the crawler will stop after reaching this limit. + zh_Hans: 指定每个网站要爬取的最大页面数。爬虫将在达到此限制后停止。 + form: form + min: 0 + default: 0 + - name: depth + type: number + required: false + label: + en_US: maximum depth of pages to crawl + zh_Hans: 最大爬取深度 + human_description: + en_US: the crawl limit for maximum depth. + zh_Hans: 最大爬取深度的限制。 + form: form + min: 0 + default: 0 + - name: blacklist + type: string + required: false + label: + en_US: url patterns to exclude + zh_Hans: 要排除的URL模式 + human_description: + en_US: blacklist a set of paths that you do not want to crawl. 
class TiandituProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Validate the Tianditu API key by running a real POI search.

        Forks a PoiSearchTool runtime with the supplied credentials and
        invokes it with a known-good query; any failure surfaces as a
        ToolProviderCredentialValidationError.

        :param credentials: must contain 'tianditu_api_key'.
        :raises ToolProviderCredentialValidationError: if the test invocation fails.
        """
        try:
            PoiSearchTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(user_id='',
                     # Parameter names must match tools/poisearch.yaml
                     # ('keyword' / 'baseAddress'). The original passed
                     # 'content'/'specify', which PoiSearchTool ignores, so the
                     # API (and thus the key) was never actually exercised.
                     tool_parameters={
                         'keyword': '北京',
                         'baseAddress': '北京',
                     })
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
class GeocoderTool(BuiltinTool):

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Forward-geocode an address via the Tianditu geocoder API.

        :param user_id: id of the invoking user (unused here).
        :param tool_parameters: expects 'keyword' — the address text to geocode.
        :return: a JSON message with the raw geocoder response, or a text
            message when the keyword is missing.
        """
        base_url = 'http://api.tianditu.gov.cn/geocoder'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        # The geocoder expects the request JSON in the 'ds' query argument.
        params = {
            'keyWord': keyword,
        }

        # Explicit (connect, read) timeout: the original had none, so a
        # stalled upstream could hang the worker indefinitely.
        result = requests.get(
            base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk,
            timeout=(10, 30),
        ).json()

        return self.create_json_message(result)
class PoiSearchTool(BuiltinTool):

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Search POIs of a given kind within 5 km of a base address.

        Geocodes 'baseAddress' first, then queries the Tianditu v2 search API
        around the resulting point.

        :param user_id: id of the invoking user (unused here).
        :param tool_parameters: expects 'keyword' (POI keyword) and
            'baseAddress' (address to center the search on).
        :return: a JSON message with the raw search response, or a text
            message describing the invalid parameter / failed geocoding.
        """
        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/v2/search'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        baseAddress = tool_parameters.get('baseAddress', '')
        if not baseAddress:
            return self.create_text_message('Invalid parameter baseAddress')

        tk = self.runtime.credentials['tianditu_api_key']

        # Geocode the base address first. Explicit timeouts: the original had
        # none and could hang the worker on a stalled upstream.
        base_coords = requests.get(
            geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress, }, ensure_ascii=False) + '&tk=' + tk,
            timeout=(10, 30),
        ).json()

        # Guard against geocoding failure: the original indexed
        # base_coords['location'] directly and raised KeyError on any error
        # response from the geocoder.
        location = base_coords.get('location')
        if not location or 'lon' not in location or 'lat' not in location:
            return self.create_text_message(f'Failed to geocode baseAddress: {baseAddress}')

        params = {
            'keyWord': keyword,
            'queryRadius': 5000,  # fixed 5 km search radius (documented in the tool yaml)
            'queryType': 3,       # "around a point" query type
            'pointLonlat': location['lon'] + ',' + location['lat'],
            'start': 0,
            'count': 100,
        }

        result = requests.get(
            base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk,
            timeout=(10, 30),
        ).json()

        return self.create_json_message(result)
# NOTE(review): this class name looks copy-pasted from tools/poisearch.py —
# the tool actually renders a static map. Renaming is deferred until it is
# confirmed nothing resolves the tool class by name.
class PoiSearchTool(BuiltinTool):

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Generate a 400x300 static map PNG centered on a geocoded keyword.

        :param user_id: id of the invoking user (unused here).
        :param tool_parameters: expects 'keyword' — the address/place to map.
        :return: a blob message carrying the PNG image, or a text message on
            invalid input / failed geocoding.
        """

        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/staticimage'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        # Explicit timeouts: the original had none and could hang the worker.
        keyword_coords = requests.get(
            geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword, }, ensure_ascii=False) + '&tk=' + tk,
            timeout=(10, 30),
        ).json()

        # Guard against geocoding failure: the original indexed
        # keyword_coords['location'] directly and raised KeyError on any
        # error response from the geocoder.
        location = keyword_coords.get('location')
        if not location or 'lon' not in location or 'lat' not in location:
            return self.create_text_message(f'Failed to geocode keyword: {keyword}')

        coords = location['lon'] + ',' + location['lat']

        # Center the map on the point and drop a marker at the same spot.
        result = requests.get(
            base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk,
            timeout=(10, 30),
        ).content

        return self.create_blob_message(blob=result,
                                        meta={'mime_type': 'image/png'},
                                        save_as=self.VARIABLE_KEY.IMAGE.value)
Generate a static map +description: + human: + en_US: Generate a static map + zh_Hans: 生成静态地图 + pt_BR: Generate a static map + llm: A tool for generate a static map +parameters: + - name: keyword + type: string + required: true + label: + en_US: keyword + zh_Hans: 搜索的关键字 + pt_BR: keyword + human_description: + en_US: keyword + zh_Hans: 搜索的关键字 + pt_BR: keyword + form: llm diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 5b053678f3..eaf58ed5bd 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -14,7 +14,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index de2ce5858a..b1e541b8db 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -8,7 +8,7 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 04c09c7f5b..5d561911d1 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping from copy import deepcopy 
from enum import Enum from typing import Any, Optional, Union @@ -190,8 +191,9 @@ class Tool(BaseModel, ABC): return result - def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: + def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters + # TODO: Fix type error. if self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) @@ -208,7 +210,7 @@ class Tool(BaseModel, ABC): return result - def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: + def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: """ Transform tool parameters type """ @@ -241,7 +243,7 @@ class Tool(BaseModel, ABC): :return: the runtime parameters """ - return self.parameters + return self.parameters or [] def get_all_runtime_parameters(self) -> list[ToolParameter]: """ diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 7615368934..0e15151aa4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,4 +1,5 @@ import json +from collections.abc import Mapping from copy import deepcopy from datetime import datetime, timezone from mimetypes import guess_type @@ -46,7 +47,7 @@ class ToolEngine: if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.get_runtime_parameters() + parameter for parameter in tool.get_runtime_parameters() or [] if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: @@ -123,8 +124,8 @@ class ToolEngine: return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: dict, - user_id: str, workflow_id: str, + def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], + user_id: str, workflow_tool_callback: 
DifyWorkflowCallbackHandler, workflow_call_depth: int, ) -> list[ToolInvokeMessage]: @@ -141,7 +142,9 @@ class ToolEngine: if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 - response = tool.invoke(user_id, tool_parameters) + if tool.runtime and tool.runtime.runtime_parameters: + tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} + response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters) # hit the callback handler workflow_tool_callback.on_tool_end( diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 207f009eed..f9f7c7d78a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,9 +9,9 @@ from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 -from flask import current_app from httpx import get +from configs import dify_config from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -26,25 +26,25 @@ class ToolFileManager: """ sign file to get a temporary url """ - base_url = current_app.config.get('FILES_URL') + base_url = dify_config.FILES_URL file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' @staticmethod def verify_file(file_id: str, 
timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -53,23 +53,23 @@ class ToolFileManager: return False current_time = int(time.time()) - return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT') + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT @staticmethod - def create_file_by_raw(user_id: str, tenant_id: str, - conversation_id: Optional[str], file_binary: bytes, - mimetype: str - ) -> ToolFile: + def create_file_by_raw( + user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str + ) -> ToolFile: """ create file """ extension = guess_extension(mimetype) or '.bin' unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" + filename = f'tools/{tenant_id}/{unique_name}{extension}' storage.save(filename, file_binary) - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, mimetype=mimetype) + tool_file = ToolFile( + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype + ) db.session.add(tool_file) db.session.commit() @@ -77,9 +77,12 @@ class ToolFileManager: return tool_file @staticmethod - def create_file_by_url(user_id: str, tenant_id: str, - conversation_id: str, file_url: str, - ) -> ToolFile: + def create_file_by_url( + user_id: str, + tenant_id: str, + conversation_id: str, + file_url: str, + ) -> ToolFile: """ create file """ @@ -90,12 +93,17 @@ class ToolFileManager: mimetype = 
guess_type(file_url)[0] or 'octet/stream' extension = guess_extension(mimetype) or '.bin' unique_name = uuid4().hex - filename = f"tools/{tenant_id}/{unique_name}{extension}" + filename = f'tools/{tenant_id}/{unique_name}{extension}' storage.save(filename, blob) - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, - mimetype=mimetype, original_url=file_url) + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filename, + mimetype=mimetype, + original_url=file_url, + ) db.session.add(tool_file) db.session.commit() @@ -103,15 +111,15 @@ class ToolFileManager: return tool_file @staticmethod - def create_file_by_key(user_id: str, tenant_id: str, - conversation_id: str, file_key: str, - mimetype: str - ) -> ToolFile: + def create_file_by_key( + user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str + ) -> ToolFile: """ create file """ - tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=file_key, mimetype=mimetype) + tool_file = ToolFile( + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype + ) return tool_file @staticmethod @@ -123,9 +131,13 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == id, - ).first() + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == id, + ) + .first() + ) if not tool_file: return None @@ -143,18 +155,31 @@ class ToolFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile = db.session.query(MessageFile).filter( - MessageFile.id == id, - ).first() + message_file: MessageFile = ( + db.session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() + ) - # get tool file id - tool_file_id = message_file.url.split('/')[-1] - # 
trim extension - tool_file_id = tool_file_id.split('.')[0] + # Check if message_file is not None + if message_file is not None: + # get tool file id + tool_file_id = message_file.url.split('/')[-1] + # trim extension + tool_file_id = tool_file_id.split('.')[0] + else: + tool_file_id = None - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == tool_file_id, - ).first() + + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) if not tool_file: return None @@ -172,9 +197,13 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile = db.session.query(ToolFile).filter( - ToolFile.id == tool_file_id, - ).first() + tool_file: ToolFile = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) if not tool_file: return None diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b9fcecc05e..4ce5e124b9 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,8 +6,7 @@ from os import listdir, path from threading import Lock from typing import Any, Union -from flask import current_app - +from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -565,7 +564,7 @@ class ToolManager: provider_type = provider_type provider_id = provider_id if provider_type == 'builtin': - return (current_app.config.get("CONSOLE_API_URL") + return (dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" + provider_id + "/icon") @@ -574,7 +573,7 @@ class ToolManager: provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id - ) + ).first() return json.loads(provider.icon) except: return { @@ -593,4 +592,4 @@ class 
ToolManager: else: raise ValueError(f"provider type {provider_type} not found") -ToolManager.load_builtin_providers_cache() \ No newline at end of file +ToolManager.load_builtin_providers_cache() diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 1e7eb129a7..e52082541a 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -10,6 +10,7 @@ import unicodedata from contextlib import contextmanager from urllib.parse import unquote +import cloudscraper import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString from newspaper import Article @@ -46,29 +47,34 @@ def get_url(url: str, user_agent: str = None) -> str: supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) + if response.status_code == 200: + # check content-type + content_type = response.headers.get('Content-Type') + if content_type: + main_content_type = response.headers.get('Content-Type').split(';')[0].strip() + else: + content_disposition = response.headers.get('Content-Disposition', '') + filename_match = re.search(r'filename="([^"]+)"', content_disposition) + if filename_match: + filename = unquote(filename_match.group(1)) + extension = re.search(r'\.(\w+)$', filename) + if extension: + main_content_type = mimetypes.guess_type(filename)[0] + + if main_content_type not in supported_content_types: + return "Unsupported content-type [{}] of URL.".format(main_content_type) + + if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: + return ExtractProcessor.load_from_url(url, return_text=True) + + response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + elif response.status_code == 403: + scraper = cloudscraper.create_scraper() + response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + if 
response.status_code != 200: return "URL returned status code {}.".format(response.status_code) - # check content-type - content_type = response.headers.get('Content-Type') - if content_type: - main_content_type = response.headers.get('Content-Type').split(';')[0].strip() - else: - content_disposition = response.headers.get('Content-Disposition', '') - filename_match = re.search(r'filename="([^"]+)"', content_disposition) - if filename_match: - filename = unquote(filename_match.group(1)) - extension = re.search(r'\.(\w+)$', filename) - if extension: - main_content_type = mimetypes.guess_type(filename)[0] - - if main_content_type not in supported_content_types: - return "Unsupported content-type [{}] of URL.".format(main_content_type) - - if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return ExtractProcessor.load_from_url(url, return_text=True) - - response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) a = extract_using_readabilipy(response.text) if not a['plain_text'] or not a['plain_text'].strip(): diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 3b0d51d868..6db8adf4c2 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -6,7 +6,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -class BaseWorkflowCallback(ABC): +class WorkflowCallback(ABC): @abstractmethod def on_workflow_run_started(self) -> None: """ @@ -78,7 +78,7 @@ class BaseWorkflowCallback(ABC): node_type: NodeType, node_run_index: int = 1, node_data: Optional[BaseNodeData] = None, - inputs: dict = None, + inputs: Optional[dict] = None, predecessor_node_id: Optional[str] = None, metadata: Optional[dict] = None) -> None: """ diff --git a/api/core/workflow/entities/node_entities.py 
b/api/core/workflow/entities/node_entities.py index d11352f066..e45ea834c4 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from enum import Enum from typing import Any, Optional @@ -82,9 +83,9 @@ class NodeRunResult(BaseModel): """ status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[dict] = None # node inputs + inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict] = None # process data - outputs: Optional[dict] = None # node outputs + outputs: Optional[Mapping[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index b12ef1c64a..38d52e0f75 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,118 +1,146 @@ -from enum import Enum -from typing import Any, Optional, Union +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import Any, Union -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field +from typing_extensions import deprecated +from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable VariableValue = Union[str, int, float, dict, list, FileVar] -class ValueType(Enum): - """ - Value Type Enum - """ - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILE = "array[file]" - FILE = "file" +SYSTEM_VARIABLE_NODE_ID = 'sys' +ENVIRONMENT_VARIABLE_NODE_ID = 'env' class 
VariablePool(BaseModel): - - variables_mapping: dict[str, dict[int, VariableValue]] = Field( + # Variable dictionary is a dictionary for looking up variables by their selector. + # The first element of the selector is the node id, it's the first-level key in the dictionary. + # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the + # elements of the selector except the first one. + _variable_dictionary: dict[str, dict[int, Variable]] = Field( description='Variables mapping', - default={}, + default=defaultdict(dict) ) - user_inputs: dict = Field( + # TODO: This user inputs is not used for pool. + user_inputs: Mapping[str, Any] = Field( description='User inputs', ) - system_variables: dict[SystemVariable, Any] = Field( + system_variables: Mapping[SystemVariable, Any] = Field( description='System variables', ) - @model_validator(mode='before') - def append_system_variables(cls, v: dict) -> dict: + environment_variables: Sequence[Variable] = Field( + description="Environment variables." 
+ ) + + def __post_init__(self): """ Append system variables - :param v: params :return: """ - v['variables_mapping'] = { - 'sys': {} - } - system_variables = v['system_variables'] - for system_variable, value in system_variables.items(): - variable_key_list_hash = hash((system_variable.value,)) - v['variables_mapping']['sys'][variable_key_list_hash] = value - return v + # Add system variables to the variable pool + for key, value in self.system_variables.items(): + self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) - def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: + # Add environment variables to the variable pool + for var in self.environment_variables or []: + self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) + + def add(self, selector: Sequence[str], value: Any, /) -> None: """ - Append variable - :param node_id: node id - :param variable_key_list: variable key list, like: ['result', 'text'] - :param value: value - :return: + Adds a variable to the variable pool. + + Args: + selector (Sequence[str]): The selector for the variable. + value (VariableValue): The value of the variable. + + Raises: + ValueError: If the selector is invalid. 
+ + Returns: + None """ - if node_id not in self.variables_mapping: - self.variables_mapping[node_id] = {} + if len(selector) < 2: + raise ValueError('Invalid selector') - variable_key_list_hash = hash(tuple(variable_key_list)) + if value is None: + return - self.variables_mapping[node_id][variable_key_list_hash] = value + if not isinstance(value, Variable): + v = factory.build_anonymous_variable(value) + else: + v = value - def get_variable_value(self, variable_selector: list[str], - target_value_type: Optional[ValueType] = None) -> Optional[VariableValue]: + hash_key = hash(tuple(selector[1:])) + self._variable_dictionary[selector[0]][hash_key] = v + + def get(self, selector: Sequence[str], /) -> Variable | None: """ - Get variable - :param variable_selector: include node_id and variables - :param target_value_type: target value type - :return: + Retrieves the value from the variable pool based on the given selector. + + Args: + selector (Sequence[str]): The selector used to identify the variable. + + Returns: + Any: The value associated with the given selector. + + Raises: + ValueError: If the selector is invalid. 
""" - if len(variable_selector) < 2: - raise ValueError('Invalid value selector') - - node_id = variable_selector[0] - if node_id not in self.variables_mapping: - return None - - # fetch variable keys, pop node_id - variable_key_list = variable_selector[1:] - - variable_key_list_hash = hash(tuple(variable_key_list)) - - value = self.variables_mapping[node_id].get(variable_key_list_hash) - - if target_value_type: - if target_value_type == ValueType.STRING: - return str(value) - elif target_value_type == ValueType.NUMBER: - return int(value) - elif target_value_type == ValueType.OBJECT: - if not isinstance(value, dict): - raise ValueError('Invalid value type: object') - elif target_value_type in [ValueType.ARRAY_STRING, - ValueType.ARRAY_NUMBER, - ValueType.ARRAY_OBJECT, - ValueType.ARRAY_FILE]: - if not isinstance(value, list): - raise ValueError(f'Invalid value type: {target_value_type.value}') + if len(selector) < 2: + raise ValueError('Invalid selector') + hash_key = hash(tuple(selector[1:])) + value = self._variable_dictionary[selector[0]].get(hash_key) return value - def clear_node_variables(self, node_id: str) -> None: + @deprecated('This method is deprecated, use `get` instead.') + def get_any(self, selector: Sequence[str], /) -> Any | None: """ - Clear node variables - :param node_id: node id - :return: + Retrieves the value from the variable pool based on the given selector. + + Args: + selector (Sequence[str]): The selector used to identify the variable. + + Returns: + Any: The value associated with the given selector. + + Raises: + ValueError: If the selector is invalid. 
""" - if node_id in self.variables_mapping: - self.variables_mapping.pop(node_id) \ No newline at end of file + if len(selector) < 2: + raise ValueError('Invalid selector') + hash_key = hash(tuple(selector[1:])) + value = self._variable_dictionary[selector[0]].get(hash_key) + + if value is None: + return value + if isinstance(value, ArrayVariable): + return [element.value for element in value.value] + if isinstance(value, ObjectVariable): + return {k: v.value for k, v in value.value.items()} + return value.value if value else None + + def remove(self, selector: Sequence[str], /): + """ + Remove variables from the variable pool based on the given selector. + + Args: + selector (Sequence[str]): A sequence of strings representing the selector. + + Returns: + None + """ + if not selector: + return + if len(selector) == 1: + self._variable_dictionary[selector[0]] = {} + return + hash_key = hash(tuple(selector[1:])) + self._variable_dictionary[selector[0]].pop(hash_key, None) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index b6c67d585c..79659cbab9 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,7 +1,5 @@ -import json from typing import cast -from core.file.file_obj import FileVar from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter @@ -18,7 +16,7 @@ from models.workflow import WorkflowNodeExecutionStatus class AnswerNode(BaseNode): _node_data_cls = AnswerNodeData - node_type = NodeType.ANSWER + _node_type: NodeType = NodeType.ANSWER def _run(self) -> NodeRunResult: """ @@ -36,31 +34,12 @@ class AnswerNode(BaseNode): if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = 
self.graph_runtime_state.variable_pool.get_variable_value( + value = self.graph_runtime_state.variable_pool.get( variable_selector=value_selector ) - text = '' - if isinstance(value, str | int | float): - text = str(value) - elif isinstance(value, dict): - # other types - text = json.dumps(value, ensure_ascii=False) - elif isinstance(value, FileVar): - # convert file to markdown - text = value.to_markdown() - elif isinstance(value, list): - for item in value: - if isinstance(item, FileVar): - text += item.to_markdown() + ' ' - - text = text.strip() - - if not text and value: - # other types - text = json.dumps(value, ensure_ascii=False) - - answer += text + if value: + answer += value.markdown else: part = cast(TextGenerateRouteChunk, part) answer += part.text diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 4a02c96c10..150d417c21 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Optional +from collections.abc import Generator, Mapping +from typing import Any, Optional from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType @@ -17,7 +17,7 @@ class BaseNode(ABC): _node_type: NodeType def __init__(self, - config: dict, + config: Mapping[str, Any], graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState, diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 52b19b6156..e3602fcd35 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -57,11 +57,8 @@ class CodeNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = 
self.graph_runtime_state.variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - - variables[variable] = value + value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + variables[variable] = value.value if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 936597c481..796fa84373 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -9,7 +9,7 @@ from models.workflow import WorkflowNodeExecutionStatus class EndNode(BaseNode): _node_data_cls = EndNodeData - node_type = NodeType.END + _node_type = NodeType.END def _run(self) -> NodeRunResult: """ @@ -22,11 +22,8 @@ class EndNode(BaseNode): outputs = {} for variable_selector in output_variables: - value = self.graph_runtime_state.variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - - outputs[variable_selector.variable] = value + value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + outputs[variable_selector.variable] = value.value if value else None return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -43,7 +40,7 @@ class EndNode(BaseNode): :return: """ node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(cls._node_data_cls, node_data) + node_data = cast(EndNodeData, node_data) return cls.extract_generate_nodes_from_node_data(graph, node_data) @@ -55,7 +52,7 @@ class EndNode(BaseNode): :param node_data: node data object :return: """ - nodes = graph.get('nodes') + nodes = graph.get('nodes', []) node_mapping = {node.get('id'): node for node in nodes} variable_selectors = node_data.outputs diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 65451452c8..90d644e0e2 100644 --- 
a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -58,4 +58,3 @@ class HttpRequestNodeData(BaseNodeData): params: str body: Optional[HttpRequestNodeBody] = None timeout: Optional[HttpRequestNodeTimeout] = None - mask_authorization_header: Optional[bool] = True diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 3736c67fb7..473d85f073 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -9,7 +9,7 @@ import httpx import core.helper.ssrf_proxy as ssrf_proxy from configs import dify_config from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import ValueType, VariablePool +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeBody, @@ -212,13 +212,11 @@ class HttpExecutor: raise ValueError('self.authorization config is required') if authorization.config is None: raise ValueError('authorization config is required') - if authorization.config.type != 'bearer' and authorization.config.header is None: - raise ValueError('authorization config header is required') if self.authorization.config.api_key is None: raise ValueError('api_key is required') - if not self.authorization.config.header: + if not authorization.config.header: authorization.config.header = 'Authorization' if self.authorization.config.type == 'bearer': @@ -283,7 +281,7 @@ class HttpExecutor: # validate response return self._validate_and_parse_response(response) - def to_raw_request(self, mask_authorization_header: Optional[bool] = True) -> str: + def to_raw_request(self) -> str: """ convert to raw request """ @@ -295,16 +293,15 @@ class HttpExecutor: headers = self._assembling_headers() for k, v in headers.items(): - if 
mask_authorization_header: - # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' - if self.authorization.config and self.authorization.config.header: - authorization_header = self.authorization.config.header + # get authorization header + if self.authorization.type == 'api-key': + authorization_header = 'Authorization' + if self.authorization.config and self.authorization.config.header: + authorization_header = self.authorization.config.header - if k.lower() == authorization_header.lower(): - raw_request += f'{k}: {"*" * len(v)}\n' - continue + if k.lower() == authorization_header.lower(): + raw_request += f'{k}: {"*" * len(v)}\n' + continue raw_request += f'{k}: {v}\n' @@ -336,16 +333,13 @@ class HttpExecutor: if variable_pool: variable_value_mapping = {} for variable_selector in variable_selectors: - value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING - ) - - if value is None: + variable = variable_pool.get(variable_selector.value_selector) + if variable is None: raise ValueError(f'Variable {variable_selector.variable} not found') - - if escape_quotes and isinstance(value, str): - value = value.replace('"', '\\"') - + if escape_quotes and isinstance(variable.value, str): + value = variable.value.replace('"', '\\"') + else: + value = variable.value variable_value_mapping[variable_selector.variable] = value return variable_template_parser.format(variable_value_mapping), variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 290c8c0dac..05690fcc01 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -3,6 +3,7 @@ from mimetypes import guess_extension from os import path from typing import cast +from core.app.segments import parser from 
core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -50,6 +51,9 @@ class HttpRequestNode(BaseNode): def _run(self) -> NodeRunResult: node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) + # TODO: Switch to use segment directly + if node_data.authorization.config and node_data.authorization.config.api_key: + node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text # init http executor http_executor = None @@ -66,9 +70,7 @@ class HttpRequestNode(BaseNode): process_data = {} if http_executor: process_data = { - 'request': http_executor.to_raw_request( - mask_authorization_header=node_data.mask_authorization_header - ), + 'request': http_executor.to_raw_request(), } return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -87,9 +89,7 @@ class HttpRequestNode(BaseNode): 'files': files, }, process_data={ - 'request': http_executor.to_raw_request( - mask_authorization_header=node_data.mask_authorization_header, - ), + 'request': http_executor.to_raw_request(), }, ) diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index da58cd0c1b..e9c325416d 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -10,7 +10,7 @@ from models.workflow import WorkflowNodeExecutionStatus class IfElseNode(BaseNode): _node_data_cls = IfElseNodeData - node_type = NodeType.IF_ELSE + _node_type = NodeType.IF_ELSE def _run(self) -> NodeRunResult: """ diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 1ce723133b..e6ca0335dd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ 
b/api/core/workflow/nodes/iteration/iteration_node.py @@ -21,7 +21,8 @@ class IterationNode(BaseIterationNode): """ Run the node. """ - iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector) + self.node_data = cast(IterationNodeData, self.node_data) + iterator = variable_pool.get_any(self.node_data.iterator_selector) if not isinstance(iterator, list): raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") @@ -64,15 +65,15 @@ class IterationNode(BaseIterationNode): """ node_data = cast(IterationNodeData, self.node_data) - variable_pool.append_variable(self.node_id, ['index'], state.index) + variable_pool.add((self.node_id, 'index'), state.index) # get the iterator value - iterator = variable_pool.get_variable_value(node_data.iterator_selector) + iterator = variable_pool.get_any(node_data.iterator_selector) if iterator is None or not isinstance(iterator, list): return if state.index < len(iterator): - variable_pool.append_variable(self.node_id, ['item'], iterator[state.index]) + variable_pool.add((self.node_id, 'item'), iterator[state.index]) def _next_iteration(self, variable_pool: VariablePool, state: IterationState): """ @@ -88,7 +89,7 @@ class IterationNode(BaseIterationNode): :return: True if iteration limit is reached, False otherwise """ node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_variable_value(node_data.iterator_selector) + iterator = variable_pool.get_any(node_data.iterator_selector) if iterator is None or not isinstance(iterator, list): return True @@ -101,9 +102,9 @@ class IterationNode(BaseIterationNode): :param variable_pool: variable pool """ output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_variable_value(output_selector) + output = variable_pool.get_any(output_selector) # clear the output for this iteration - variable_pool.append_variable(self.node_id, output_selector[1:], None) + 
variable_pool.remove([self.node_id] + output_selector[1:]) state.current_output = output if output is not None: state.outputs.append(output) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index fdba392fe1..bc3e737e22 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -21,7 +21,7 @@ from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -40,7 +40,8 @@ class KnowledgeRetrievalNode(BaseNode): node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - query = self.graph_runtime_state.variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) + query = variable.value if variable else None variables = { 'query': query } diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 42bc62a8ee..6726155487 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -44,7 +44,7 @@ from models.workflow import WorkflowNodeExecutionStatus class LLMNode(BaseNode): _node_data_cls = LLMNodeData - node_type = NodeType.LLM + _node_type = NodeType.LLM def _run(self) -> Generator[RunEvent, None, None]: """ @@ -98,7 +98,7 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value]) # type: ignore + 
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -276,8 +276,8 @@ class LLMNode(BaseNode): for variable_selector in node_data.prompt_config.jinja2_variables or []: variable = variable_selector.variable - value = variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector + value = variable_pool.get_any( + variable_selector.value_selector ) def parse_dict(d: dict) -> str: @@ -340,7 +340,7 @@ class LLMNode(BaseNode): variable_selectors = variable_template_parser.extract_variable_selectors() for variable_selector in variable_selectors: - variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f'Variable {variable_selector.variable} not found') @@ -351,7 +351,7 @@ class LLMNode(BaseNode): query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) .extract_variable_selectors()) for variable_selector in query_variable_selectors: - variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + variable_value = variable_pool.get_any(variable_selector.value_selector) if variable_value is None: raise ValueError(f'Variable {variable_selector.variable} not found') @@ -369,7 +369,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) + files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) if not files: return [] @@ -388,7 +388,7 @@ class LLMNode(BaseNode): if not node_data.context.variable_selector: return - context_value = variable_pool.get_variable_value(node_data.context.variable_selector) + context_value = variable_pool.get_any(node_data.context.variable_selector) if context_value: if 
isinstance(context_value, str): yield RunRetrieverResourceEvent( @@ -530,7 +530,7 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) if conversation_id is None: return None diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index eebc36cc55..6d4b957671 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -71,9 +71,10 @@ class ParameterExtractorNode(LLMNode): Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - query = self.graph_runtime_state.variable_pool.get_variable_value(node_data.query) - if not query: + variable = self.graph_runtime_state.variable_pool.get(node_data.query) + if not variable: raise ValueError("Input variable content not found or is empty") + query = variable.value inputs = { 'query': query, @@ -567,7 +568,8 @@ class ParameterExtractorNode(LLMNode): variable_template_parser = VariableTemplateParser(instruction) inputs = {} for selector in variable_template_parser.extract_variable_selectors(): - inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) + variable = variable_pool.get(selector.value_selector) + inputs[selector.variable] = variable.value if variable else None return variable_template_parser.format(inputs) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index fda679adc1..a21e111b95 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -43,7 +43,8 
@@ class QuestionClassifierNode(LLMNode): variable_pool = self.graph_runtime_state.variable_pool # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + variable = variable_pool.get(node_data.query_variable_selector) + query = variable.value if variable else None variables = { 'query': query } @@ -305,7 +306,8 @@ class QuestionClassifierNode(LLMNode): variable_template_parser = VariableTemplateParser(template=instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: - variable_value = variable_pool.get_variable_value(variable_selector.value_selector) + variable = variable_pool.get(variable_selector.value_selector) + variable_value = variable.value if variable else None if variable_value is None: raise ValueError(f'Variable {variable_selector.variable} not found') diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index a534b5b97f..61880b82ff 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -8,7 +8,7 @@ from models.workflow import WorkflowNodeExecutionStatus class StartNode(BaseNode): _node_data_cls = StartNodeData - node_type = NodeType.START + _node_type = NodeType.START def _run(self) -> NodeRunResult: """ @@ -16,7 +16,7 @@ class StartNode(BaseNode): :return: """ # Get cleaned inputs - cleaned_inputs = self.graph_runtime_state.variable_pool.user_inputs + cleaned_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) for var in self.graph_runtime_state.variable_pool.system_variables: cleaned_inputs['sys.' 
+ var.value] = self.graph_runtime_state.variable_pool.system_variables[var] diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 923c2ae1ae..3406763b97 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -44,12 +44,9 @@ class TemplateTransformNode(BaseNode): # Get variables variables = {} for variable_selector in node_data.variables: - variable = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get_variable_value( - variable_selector=variable_selector.value_selector - ) - - variables[variable] = value + variable_name = variable_selector.variable + value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector) + variables[variable_name] = value # Run code try: result = CodeExecutor.execute_workflow_code_template( diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 2e4743c483..5da5cd0727 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -29,6 +29,7 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): + # TODO: check this type value: Union[Any, list[str]] type: Literal['mixed', 'variable', 'constant'] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 449e838617..b946eb5816 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,10 +1,11 @@ +from collections.abc import Mapping, Sequence from os import path -from typing import Optional, cast +from typing import Any, cast +from core.app.segments import parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import 
FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.tool.tool import Tool from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer @@ -20,6 +21,7 @@ class ToolNode(BaseNode): """ Tool Node """ + _node_data_cls = ToolNodeData _node_type = NodeType.TOOL @@ -50,23 +52,24 @@ class ToolNode(BaseNode): }, error=f'Failed to get tool runtime: {str(e)}' ) - + # get parameters - parameters = self._generate_parameters(self.graph_runtime_state.variable_pool, node_data, tool_runtime) + tool_parameters = tool_runtime.get_runtime_parameters() or [] + parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) + parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) try: messages = ToolEngine.workflow_invoke( tool=tool_runtime, tool_parameters=parameters, user_id=self.user_id, - workflow_id=self.workflow_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, ) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters, + inputs=parameters_for_log, metadata={ NodeRunMetadataKey.TOOL_INFO: tool_info }, @@ -86,21 +89,34 @@ class ToolNode(BaseNode): metadata={ NodeRunMetadataKey.TOOL_INFO: tool_info }, - inputs=parameters + inputs=parameters_for_log ) - def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict: + def _generate_parameters( + self, + *, + tool_parameters: Sequence[ToolParameter], + variable_pool: VariablePool, + node_data: ToolNodeData, + for_log: bool = False, + ) -> Mapping[str, Any]: """ - Generate parameters - """ - 
tool_parameters = tool_runtime.get_all_runtime_parameters() + Generate parameters based on the given tool parameters, variable pool, and node data. - def fetch_parameter(name: str) -> Optional[ToolParameter]: - return next((parameter for parameter in tool_parameters if parameter.name == name), None) + Args: + tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (ToolNodeData): The data associated with the tool node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. + + """ + tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} result = {} for parameter_name in node_data.tool_parameters: - parameter = fetch_parameter(parameter_name) + parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: continue if parameter.type == ToolParameter.ToolParameterType.FILE: @@ -108,35 +124,21 @@ class ToolNode(BaseNode): v.to_dict() for v in self._fetch_files(variable_pool) ] else: - input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': - result[parameter_name] = self._format_variable_template(input.value, variable_pool) - elif input.type == 'variable': - result[parameter_name] = variable_pool.get_variable_value(input.value) - elif input.type == 'constant': - result[parameter_name] = input.value + tool_input = node_data.tool_parameters[parameter_name] + segment_group = parser.convert_template( + template=str(tool_input.value), + variable_pool=variable_pool, + ) + result[parameter_name] = segment_group.log if for_log else segment_group.text return result - - def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str: - """ - Format variable template - """ - inputs = {} - template_parser = VariableTemplateParser(template) - for selector in template_parser.extract_variable_selectors(): - inputs[selector.variable] = 
variable_pool.get_variable_value(selector.value_selector) - - return template_parser.format(inputs) - - def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) - if not files: - return [] - - return files - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) \ + def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: + # FIXME: ensure this is a ArrayVariable contains FileVariable. + variable = variable_pool.get(['sys', SystemVariable.FILES.value]) + return [file_var.value for file_var in variable.value] if variable else [] + + def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\ -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 0576e07824..144cb015da 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -18,28 +18,27 @@ class VariableAggregatorNode(BaseNode): inputs = {} if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: - for variable in node_data.variables: - value = self.graph_runtime_state.variable_pool.get_variable_value(variable) - - if value is not None: + for selector in node_data.variables: + variable = self.graph_runtime_state.variable_pool.get(selector) + if variable is not None: outputs = { - "output": value + "output": variable.value } inputs = { - '.'.join(variable[1:]): value + '.'.join(selector[1:]): variable.value } break else: for group in node_data.advanced_settings.groups: - for variable in group.variables: - value = self.graph_runtime_state.variable_pool.get_variable_value(variable) + for selector in group.variables: + variable = 
self.graph_runtime_state.variable_pool.get(selector) - if value is not None: + if variable is not None: outputs[group.group_name] = { - 'output': value + 'output': variable.value } - inputs['.'.join(variable[1:])] = value + inputs['.'.join(selector[1:])] = variable.value break return NodeRunResult( diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index 925c31a6aa..c43fde172c 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -1,12 +1,48 @@ import re +from collections.abc import Mapping +from typing import Any from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") +REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') + + +def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: + """ + This is an alternative to the VariableTemplateParser class, + offering the same functionality but with better readability and ease of use. + """ + variable_keys = [match[0] for match in re.findall(REGEX, template)] + variable_keys = list(set(variable_keys)) + + # This key_selector is a tuple of (key, selector) where selector is a list of keys + # e.g. 
('#node_id.query.name#', ['node_id', 'query', 'name']) + key_selectors = filter( + lambda t: len(t[1]) >= 2, + ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), + ) + inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} + + def replacer(match): + key = match.group(1) + # return original matched string if key not found + value = inputs.get(key, match.group(0)) + if value is None: + value = '' + value = str(value) + # remove template variables if required + return re.sub(REGEX, r'{\1}', value) + + result = re.sub(REGEX, replacer, template) + result = re.sub(r'<\|.*?\|>', '', result) + return result class VariableTemplateParser: """ + !NOTE: Consider to use the new `segments` module instead of this class. + A class for parsing and manipulating template variables in a string. Rules: @@ -70,14 +106,11 @@ class VariableTemplateParser: if len(split_result) < 2: continue - variable_selectors.append(VariableSelector( - variable=variable_key, - value_selector=split_result - )) + variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result)) return variable_selectors - def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def format(self, inputs: Mapping[str, Any]) -> str: """ Formats the template string by replacing the template variables with their corresponding values. @@ -88,17 +121,19 @@ class VariableTemplateParser: Returns: The formatted string with template variables replaced by their values. 
""" + def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found + + if value is None: + value = '' # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) - + # remove template variables if required - if remove_template_variables: - return VariableTemplateParser.remove_template_variables(value) - return value + return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) return re.sub(r'<\|.*?\|>', '', prompt) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 98baa024ae..3cff6e8505 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,10 +1,9 @@ import logging import time -from collections.abc import Generator +from collections.abc import Mapping, Sequence from typing import Any, Optional, cast -from flask import current_app - +from configs import dify_config from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom @@ -13,7 +12,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine @@ -37,11 +35,11 @@ class WorkflowEntry: user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, - callbacks: 
list[BaseWorkflowCallback], - user_inputs: dict, - system_inputs: dict[SystemVariable, Any], + user_inputs: Mapping[str, Any], + system_inputs: Mapping[SystemVariable, Any], + callbacks: Sequence[BaseWorkflowCallback], call_depth: int = 0, - variable_pool: Optional[VariablePool] = None) -> Generator: + variable_pool: Optional[VariablePool] = None) -> None: """ :param workflow: Workflow instance :param user_id: user id @@ -71,9 +69,14 @@ class WorkflowEntry: if not variable_pool: variable_pool = VariablePool( system_variables=system_inputs, - user_inputs=user_inputs + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, ) + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + # init graph graph = Graph.init( graph_config=graph_config @@ -124,11 +127,11 @@ class WorkflowEntry: return rst - def _run_workflow(self, graph_config: dict, - workflow_runtime_state: WorkflowRuntimeState, - callbacks: list[BaseWorkflowCallback], - start_node: Optional[str] = None, - end_node: Optional[str] = None) -> None: + def _run_workflow(self, workflow: Workflow, + workflow_run_state: WorkflowRunState, + callbacks: Sequence[BaseWorkflowCallback], + start_at: Optional[str] = None, + end_at: Optional[str] = None) -> None: """ Run workflow :param graph_config: workflow graph config @@ -149,12 +152,11 @@ class WorkflowEntry: error='Start node not found in workflow graph.' 
) - predecessor_node: Optional[BaseNode] = None - current_iteration_node: Optional[BaseIterationNode] = None - max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") - max_execution_steps = cast(int, max_execution_steps) - max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") - max_execution_time = cast(int, max_execution_time) + predecessor_node: BaseNode | None = None + current_iteration_node: BaseIterationNode | None = None + has_entry_node = False + max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS + max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME while True: # get next nodes next_nodes = self._get_next_overall_nodes( @@ -212,7 +214,7 @@ class WorkflowEntry: # move to next iteration next_node_id = next_iteration # get next id - next_nodes = [self._get_node(workflow_run_state, graph, next_node_id, callbacks)] + next_nodes = [self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)] if not next_nodes: break @@ -423,7 +425,8 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( system_variables={}, - user_inputs={} + user_inputs={}, + environment_variables=workflow.environment_variables, ) # variable selector to variable mapping @@ -458,11 +461,11 @@ class WorkflowEntry: return node_instance, node_run_result def single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: list[BaseWorkflowCallback] = None, - ) -> None: + node_id: str, + user_id: str, + user_inputs: dict, + callbacks: Sequence[BaseWorkflowCallback], + ) -> None: """ Single iteration run workflow node """ @@ -488,7 +491,8 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( system_variables={}, - user_inputs={} + user_inputs={}, + environment_variables=workflow.environment_variables, ) # variable selector to variable mapping @@ -604,7 +608,7 @@ class WorkflowEntry: 
for callback in callbacks: callback.on_workflow_run_started() - def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: + def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: """ Workflow run success :param callbacks: workflow callbacks @@ -616,7 +620,7 @@ class WorkflowEntry: callback.on_workflow_run_succeeded() def _workflow_run_failed(self, error: str, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: Sequence[WorkflowCallback]) -> None: """ Workflow run failed :param error: error message @@ -629,11 +633,11 @@ class WorkflowEntry: error=error ) - def _workflow_iteration_started(self, graph: dict, + def _workflow_iteration_started(self, *, graph: Mapping[str, Any], current_iteration_node: BaseIterationNode, workflow_run_state: WorkflowRunState, predecessor_node_id: Optional[str] = None, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: Sequence[WorkflowCallback]) -> None: """ Workflow iteration started :param current_iteration_node: current iteration node @@ -666,10 +670,10 @@ class WorkflowEntry: # add steps workflow_run_state.workflow_node_steps += 1 - def _workflow_iteration_next(self, graph: dict, + def _workflow_iteration_next(self, *, graph: Mapping[str, Any], current_iteration_node: BaseIterationNode, workflow_run_state: WorkflowRunState, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: Sequence[BaseWorkflowCallback]) -> None: """ Workflow iteration next :param workflow_run_state: workflow run state @@ -696,11 +700,11 @@ class WorkflowEntry: nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] for node in nodes: - workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) + workflow_run_state.variable_pool.remove((node.get('id'),)) - def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - 
callbacks: list[BaseWorkflowCallback] = None) -> None: + def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, + workflow_run_state: WorkflowRunState, + callbacks: Sequence[BaseWorkflowCallback]) -> None: if callbacks: if isinstance(workflow_run_state.current_iteration_state, IterationState): for callback in callbacks: @@ -713,12 +717,12 @@ class WorkflowEntry: } ) - def _get_next_overall_nodes(self, workflow_run_state: WorkflowRunState, - graph: dict, - callbacks: list[BaseWorkflowCallback], - predecessor_node: Optional[BaseNode] = None, - node_start_at: Optional[str] = None, - node_end_at: Optional[str] = None) -> list[BaseNode]: + def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, + graph: Mapping[str, Any], + callbacks: list[BaseWorkflowCallback], + predecessor_node: Optional[BaseNode] = None, + node_start_at: Optional[str] = None, + node_end_at: Optional[str] = None) -> Optional[BaseNode]: """ Get next nodes multiple target nodes in the future. 
@@ -804,26 +808,26 @@ class WorkflowEntry: if not target_node_cls: continue - target_node = target_node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) + target_node = target_node_cls( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, + config=target_node_config, + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth + ) - target_nodes.append(target_node) + target_nodes.append(target_node) - return target_nodes + return target_nodes def _get_node(self, workflow_run_state: WorkflowRunState, - graph: dict, + graph: Mapping[str, Any], node_id: str, - callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: + callbacks: Sequence[WorkflowCallback]): """ Get node from graph by node id """ @@ -834,7 +838,7 @@ class WorkflowEntry: for node_config in nodes: if node_config.get('id') == node_id: node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) + node_cls = node_classes[node_type] return node_cls( tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, @@ -847,8 +851,6 @@ class WorkflowEntry: workflow_call_depth=workflow_run_state.workflow_call_depth ) - return None - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ Check timeout @@ -867,10 +869,10 @@ class WorkflowEntry: if node_and_result.node_id == node_id ]) - def _run_workflow_node(self, workflow_run_state: WorkflowRunState, + def _run_workflow_node(self, *, 
workflow_run_state: WorkflowRunState, node: BaseNode, predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None) -> None: + callbacks: Sequence[WorkflowCallback]) -> None: if callbacks: for callback in callbacks: callback.on_workflow_node_execute_started( @@ -973,10 +975,8 @@ class WorkflowEntry: :param variable_value: variable value :return: """ - variable_pool.append_variable( - node_id=node_id, - variable_key_list=variable_key_list, - value=variable_value + variable_pool.add( + [node_id] + variable_key_list, variable_value ) # if variable_value is a dict, then recursively append variables @@ -1025,7 +1025,7 @@ class WorkflowEntry: tenant_id: str, node_instance: BaseNode): for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs: + if variable_key not in user_inputs and not variable_pool.get(variable_selector): raise ValueError(f'Variable key {variable_key} not found in user inputs.') # fetch variable node id from variable selector @@ -1035,7 +1035,7 @@ class WorkflowEntry: # get value value = user_inputs.get(variable_key) - # temp fix for image type + # FIXME: temp fix for image type if node_instance.node_type == NodeType.LLM: new_value = [] if isinstance(value, list): @@ -1062,11 +1062,7 @@ class WorkflowEntry: value = new_value # append variable and value to variable pool - variable_pool.append_variable( - node_id=variable_node_id, - variable_key_list=variable_key_list, - value=value - ) + variable_pool.add([variable_node_id]+variable_key_list, value) class WorkflowRunFailedError(Exception): diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index e74c6c2406..a53d84c6e9 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -8,15 +8,15 @@ if [[ "${MIGRATION_ENABLED}" == "true" ]]; then fi if [[ "${MODE}" == "worker" ]]; then - celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \ + exec celery -A 
app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \ -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion} elif [[ "${MODE}" == "beat" ]]; then - celery -A app.celery beat --loglevel INFO + exec celery -A app.celery beat --loglevel INFO else if [[ "${DEBUG}" == "true" ]]; then - flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug else - gunicorn \ + exec gunicorn \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \ @@ -24,4 +24,4 @@ else --preload \ app:app fi -fi \ No newline at end of file +fi diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 8302f91a43..ae9a075340 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -43,15 +43,15 @@ def init_app(app: Flask) -> Celery: "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", ] - + day = app.config["CELERY_BEAT_SCHEDULER_TIME"] beat_schedule = { 'clean_embedding_cache_task': { 'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', - 'schedule': timedelta(days=1), + 'schedule': timedelta(days=day), }, 'clean_unused_datasets_task': { 'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task', - 'schedule': timedelta(minutes=3), + 'schedule': timedelta(days=day), } } celery_app.conf.update( diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 54d7ed55f8..c98c332021 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,38 @@ from flask_restful import fields +from core.app.segments import SecretVariable, Variable +from core.helper import encrypter from fields.member_fields import simple_account_fields from libs.helper import TimestampField + +class 
EnvironmentVariableField(fields.Raw): + def format(self, value): + # Mask secret variables values in environment_variables + if isinstance(value, SecretVariable): + return { + 'id': value.id, + 'name': value.name, + 'value': encrypter.obfuscated_token(value.value), + 'value_type': value.value_type.value, + } + elif isinstance(value, Variable): + return { + 'id': value.id, + 'name': value.name, + 'value': value.value, + 'value_type': value.value_type.value, + } + return value + + +environment_variable_fields = { + 'id': fields.String, + 'name': fields.String, + 'value': fields.Raw, + 'value_type': fields.String(attribute='value_type.value'), +} + workflow_fields = { 'id': fields.String, 'graph': fields.Raw(attribute='graph_dict'), @@ -13,4 +43,5 @@ workflow_fields = { 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_at': TimestampField, 'tool_published': fields.Boolean, + 'environment_variables': fields.List(EnvironmentVariableField()), } diff --git a/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py new file mode 100644 index 0000000000..7445f664cd --- /dev/null +++ b/api/migrations/versions/6e957a32015b_add_embedding_cache_created_at_index.py @@ -0,0 +1,32 @@ +"""add-embedding-cache-created_at_index + +Revision ID: 6e957a32015b +Revises: fecff1c3da27 +Create Date: 2024-07-19 17:21:34.414705 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '6e957a32015b' +down_revision = 'fecff1c3da27' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.create_index('created_at_idx', ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.drop_index('created_at_idx') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py new file mode 100644 index 0000000000..ec2336da4d --- /dev/null +++ b/api/migrations/versions/8e5588e6412e_add_environment_variable_to_workflow_.py @@ -0,0 +1,33 @@ +"""add environment variable to workflow model + +Revision ID: 8e5588e6412e +Revises: 6e957a32015b +Create Date: 2024-07-22 03:27:16.042533 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '8e5588e6412e' +down_revision = '6e957a32015b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('environment_variables', sa.Text(), server_default='{}', nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('environment_variables') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py new file mode 100644 index 0000000000..271b2490de --- /dev/null +++ b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py @@ -0,0 +1,54 @@ +"""remove extra tracing app config table and add idx_dataset_permissions_tenant_id + +Revision ID: fecff1c3da27 +Revises: 408176b91ad3 +Create Date: 2024-07-19 12:03:21.217463 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'fecff1c3da27' +down_revision = '408176b91ad3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tracing_app_configs') + + with op.batch_alter_table('trace_app_config', schema=None) as batch_op: + batch_op.drop_index('tracing_app_config_app_id_idx') + + # idx_dataset_permissions_tenant_id + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + batch_op.create_index('idx_dataset_permissions_tenant_id', ['tenant_id']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + 'tracing_app_configs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('tracing_provider', sa.String(length=255), nullable=True), + sa.Column('tracing_config', postgresql.JSON(astext_type=sa.Text()), nullable=True), + sa.Column( + 'created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False + ), + sa.Column( + 'updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False + ), + sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') + ) + + with op.batch_alter_table('trace_app_config', schema=None) as batch_op: + batch_op.create_index('tracing_app_config_app_id_idx', ['app_id']) + + with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: + batch_op.drop_index('idx_dataset_permissions_tenant_id') + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 23e7528d22..d36b2b9fda 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -48,7 +48,7 @@ class Account(UserMixin, db.Model): return self._current_tenant @current_tenant.setter - def current_tenant(self, value): + def current_tenant(self, value: "Tenant"): tenant = value ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first() if ta: @@ -62,7 +62,7 @@ class Account(UserMixin, db.Model): return self._current_tenant.id @current_tenant_id.setter - def current_tenant_id(self, value): + def current_tenant_id(self, value: str): try: tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ .filter(Tenant.id == value) \ diff --git a/api/models/dataset.py b/api/models/dataset.py index 02d49380bd..34dde2dcef 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,10 +9,10 @@ import re import time from json import JSONDecodeError -from flask import current_app 
from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB +from configs import dify_config from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from extensions.ext_storage import storage @@ -68,7 +68,7 @@ class Dataset(db.Model): @property def created_by_account(self): - return Account.query.get(self.created_by) + return db.session.get(Account, self.created_by) @property def latest_process_rule(self): @@ -117,7 +117,7 @@ class Dataset(db.Model): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -336,7 +336,7 @@ class Document(db.Model): @property def dataset_process_rule(self): if self.dataset_process_rule_id: - return DatasetProcessRule.query.get(self.dataset_process_rule_id) + return db.session.get(DatasetProcessRule, self.dataset_process_rule_id) return None @property @@ -528,7 +528,7 @@ class DocumentSegment(db.Model): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = current_app.config['SECRET_KEY'].encode() + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -560,7 +560,7 @@ class AppDatasetJoin(db.Model): @property def app(self): - return App.query.get(self.app_id) + return db.session.get(App, self.app_id) class DatasetQuery(db.Model): @@ -630,7 +630,8 @@ class Embedding(db.Model): __tablename__ = 'embeddings' __table_args__ = ( db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx') + db.UniqueConstraint('model_name', 'hash', 
'provider_name', name='embedding_hash_idx'), + db.Index('created_at_idx', 'created_at') ) id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) diff --git a/api/models/model.py b/api/models/model.py index 4d67272c1a..396cd7ec63 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,10 +4,11 @@ import uuid from enum import Enum from typing import Optional -from flask import current_app, request +from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text +from configs import dify_config from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db @@ -111,7 +112,7 @@ class App(db.Model): @property def api_base_url(self): - return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] + return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip('/')) + '/v1' @property @@ -1113,7 +1114,7 @@ class Site(db.Model): @property def app_base_url(self): return ( - current_app.config['APP_WEB_URL'] if current_app.config['APP_WEB_URL'] else request.host_url.rstrip('/')) + dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.host_url.rstrip('/')) class ApiToken(db.Model): @@ -1382,7 +1383,7 @@ class TraceAppConfig(db.Model): __tablename__ = 'trace_app_config' __table_args__ = ( db.PrimaryKeyConstraint('id', name='tracing_app_config_pkey'), - db.Index('tracing_app_config_app_id_idx', 'app_id'), + db.Index('trace_app_config_app_id_idx', 'app_id'), ) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) diff --git a/api/models/workflow.py b/api/models/workflow.py index 2d6491032b..df2269cd0f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,16 @@ import json +from collections.abc import Mapping, Sequence from enum import Enum -from typing import Optional, Union +from typing import Any, 
Optional, Union +import contexts +from constants import HIDDEN_VALUE +from core.app.segments import ( + SecretVariable, + Variable, + factory, +) +from core.helper import encrypter from extensions.ext_database import db from libs import helper from models import StringUUID @@ -112,21 +121,22 @@ class Workflow(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_by = db.Column(StringUUID) updated_at = db.Column(db.DateTime) + _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') @property def created_by_account(self): - return Account.query.get(self.created_by) + return db.session.get(Account, self.created_by) @property def updated_by_account(self): - return Account.query.get(self.updated_by) if self.updated_by else None + return db.session.get(Account, self.updated_by) if self.updated_by else None @property - def graph_dict(self): - return json.loads(self.graph) if self.graph else None + def graph_dict(self) -> Mapping[str, Any]: + return json.loads(self.graph) if self.graph else {} @property - def features_dict(self): + def features_dict(self) -> Mapping[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -177,6 +187,72 @@ class Workflow(db.Model): WorkflowToolProvider.app_id == self.app_id ).first() is not None + @property + def environment_variables(self) -> Sequence[Variable]: + # TODO: find some way to init `self._environment_variables` when instance created. 
+ if self._environment_variables is None: + self._environment_variables = '{}' + + tenant_id = contexts.tenant_id.get() + + environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) + results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] + + # decrypt secret variables value + decrypt_func = ( + lambda var: var.model_copy( + update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} + ) + if isinstance(var, SecretVariable) + else var + ) + results = list(map(decrypt_func, results)) + return results + + @environment_variables.setter + def environment_variables(self, value: Sequence[Variable]): + tenant_id = contexts.tenant_id.get() + + value = list(value) + if any(var for var in value if not var.id): + raise ValueError('environment variable require a unique id') + + # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). + origin_variables_dictionary = {var.id: var for var in self.environment_variables} + for i, variable in enumerate(value): + if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: + value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) + + # encrypt secret variables value + encrypt_func = ( + lambda var: var.model_copy( + update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} + ) + if isinstance(var, SecretVariable) + else var + ) + encrypted_vars = list(map(encrypt_func, value)) + environment_variables_json = json.dumps( + {var.name: var.model_dump() for var in encrypted_vars}, + ensure_ascii=False, + ) + self._environment_variables = environment_variables_json + + def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: + environment_variables = list(self.environment_variables) + environment_variables = [ + v if not isinstance(v, SecretVariable) or include_secret else 
v.model_copy(update={'value': ''}) + for v in environment_variables + ] + + result = { + 'graph': self.graph_dict, + 'features': self.features_dict, + 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], + } + return result + + class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum @@ -290,14 +366,14 @@ class WorkflowRun(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return Account.query.get(self.created_by) \ + return db.session.get(Account, self.created_by) \ if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) - return EndUser.query.get(self.created_by) \ + return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None @property @@ -500,14 +576,14 @@ class WorkflowNodeExecution(db.Model): @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return Account.query.get(self.created_by) \ + return db.session.get(Account, self.created_by) \ if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) - return EndUser.query.get(self.created_by) \ + return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None @property @@ -612,17 +688,17 @@ class WorkflowAppLog(db.Model): @property def workflow_run(self): - return WorkflowRun.query.get(self.workflow_run_id) + return db.session.get(WorkflowRun, self.workflow_run_id) @property def created_by_account(self): created_by_role = CreatedByRole.value_of(self.created_by_role) - return Account.query.get(self.created_by) \ + return db.session.get(Account, self.created_by) \ if created_by_role == 
CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser created_by_role = CreatedByRole.value_of(self.created_by_role) - return EndUser.query.get(self.created_by) \ + return db.session.get(EndUser, self.created_by) \ if created_by_role == CreatedByRole.END_USER else None diff --git a/api/poetry.lock b/api/poetry.lock index ca967c57cd..4b90b63e9f 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1610,6 +1610,22 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"] numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"] zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] +[[package]] +name = "cloudscraper" +version = "1.2.71" +description = "A Python module to bypass Cloudflare's anti-bot page." +optional = false +python-versions = "*" +files = [ + {file = "cloudscraper-1.2.71-py2.py3-none-any.whl", hash = "sha256:76f50ca529ed2279e220837befdec892626f9511708e200d48d5bb76ded679b0"}, + {file = "cloudscraper-1.2.71.tar.gz", hash = "sha256:429c6e8aa6916d5bad5c8a5eac50f3ea53c9ac22616f6cb21b18dcc71517d0d3"}, +] + +[package.dependencies] +pyparsing = ">=2.4.7" +requests = ">=2.9.2" +requests-toolbelt = ">=0.9.1" + [[package]] name = "cohere" version = "5.2.6" @@ -2486,18 +2502,18 @@ docs = ["sphinx"] [[package]] name = "flask-sqlalchemy" -version = "3.0.5" +version = "3.1.1" description = "Add SQLAlchemy support to your Flask application." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "flask_sqlalchemy-3.0.5-py3-none-any.whl", hash = "sha256:cabb6600ddd819a9f859f36515bb1bd8e7dbf30206cc679d2b081dff9e383283"}, - {file = "flask_sqlalchemy-3.0.5.tar.gz", hash = "sha256:c5765e58ca145401b52106c0f46178569243c5da25556be2c231ecc60867c5b1"}, + {file = "flask_sqlalchemy-3.1.1-py3-none-any.whl", hash = "sha256:4ba4be7f419dc72f4efd8802d69974803c37259dd42f3913b0dcf75c9447e0a0"}, + {file = "flask_sqlalchemy-3.1.1.tar.gz", hash = "sha256:e4b68bb881802dda1a7d878b2fc84c06d1ee57fb40b874d3dc97dabfa36b8312"}, ] [package.dependencies] flask = ">=2.2.5" -sqlalchemy = ">=1.4.18" +sqlalchemy = ">=2.0.16" [[package]] name = "flatbuffers" @@ -2794,8 +2810,8 @@ files = [ [package.dependencies] cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} greenlet = [ - {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, {version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""}, + {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, ] "zope.event" = "*" "zope.interface" = "*" @@ -2899,12 +2915,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and 
extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -4221,8 +4237,8 @@ files = [ [package.dependencies] orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] requests = ">=2,<3" @@ -5633,9 +5649,9 @@ bottleneck = {version = ">=1.3.6", optional = true, markers = "extra == \"perfor numba = {version = ">=0.56.4", optional = true, markers = "extra == \"performance\""} numexpr = {version = ">=2.8.4", optional = true, markers = "extra == \"performance\""} numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] odfpy = {version = ">=1.4.1", optional = true, markers = "extra == \"excel\""} openpyxl = {version = ">=3.1.0", optional = true, markers = "extra == \"excel\""} @@ -6194,8 +6210,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -6964,8 +6980,8 @@ grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" httpx = {version = ">=0.14.0", extras = ["http2"]} numpy = [ - {version = ">=1.26", markers = "python_version >= \"3.12\""}, {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version 
< \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" @@ -7304,6 +7320,20 @@ requests = ">=2.0.0" [package.extras] rsa = ["oauthlib[signedtoken] (>=3.0.0)"] +[[package]] +name = "requests-toolbelt" +version = "1.0.0" +description = "A utility belt for advanced users of python-requests" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, + {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, +] + +[package.dependencies] +requests = ">=2.0.1,<3.0.0" + [[package]] name = "resend" version = "0.7.2" @@ -9408,4 +9438,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5c30434ef3021083e74389544da4176c49aae15f530f30647793e240823f3fef" +content-hash = "9b1821b6e5d6d44947cc011c2d635a366557582b4540b99e0ff53a3078a989e5" diff --git a/api/pyproject.toml b/api/pyproject.toml index b5d66184be..d37d4c21f0 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -121,7 +121,7 @@ flask-cors = "~4.0.0" flask-login = "~0.6.3" flask-migrate = "~4.0.5" flask-restful = "~0.3.10" -flask-sqlalchemy = "~3.0.5" +Flask-SQLAlchemy = "~3.1.1" gevent = "~23.9.1" gmpy2 = "~2.1.5" google-ai-generativelanguage = "0.6.1" @@ -193,6 +193,7 @@ twilio = "~9.0.4" vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } wikipedia = "1.4.0" yfinance = "~0.2.40" +cloudscraper = "1.2.71" ############################################################ # VDB dependencies required by vector store clients diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 3d49b487c6..ccc1062266 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ 
b/api/schedule/clean_embedding_cache_task.py @@ -2,6 +2,7 @@ import datetime import time import click +from sqlalchemy import text from werkzeug.exceptions import NotFound import app @@ -16,16 +17,21 @@ def clean_embedding_cache_task(): clean_days = int(dify_config.CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) - page = 1 while True: try: - embeddings = db.session.query(Embedding).filter(Embedding.created_at < thirty_days_ago) \ - .order_by(Embedding.created_at.desc()).paginate(page=page, per_page=100) + embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ + .order_by(Embedding.created_at.desc()).limit(100).all() + embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] except NotFound: break - for embedding in embeddings: - db.session.delete(embedding) - db.session.commit() - page += 1 + if embedding_ids: + for embedding_id in embedding_ids: + db.session.execute(text( + "DELETE FROM embeddings WHERE id = :embedding_id" + ), {'embedding_id': embedding_id}) + + db.session.commit() + else: + break end_at = time.perf_counter() click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 2033791ace..b2b2f82b78 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -2,6 +2,7 @@ import datetime import time import click +from sqlalchemy import func from werkzeug.exceptions import NotFound import app @@ -14,16 +15,52 @@ from models.dataset import Dataset, DatasetQuery, Document @app.celery.task(queue='dataset') def clean_unused_datasets_task(): click.echo(click.style('Start clean unused datasets indexes.', fg='green')) - clean_days = int(dify_config.CLEAN_DAY_SETTING) + clean_days = dify_config.CLEAN_DAY_SETTING start_at = 
time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.created_at < thirty_days_ago) \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + # Subquery for counting new documents + document_subquery_new = db.session.query( + Document.dataset_id, + func.count(Document.id).label('document_count') + ).filter( + Document.indexing_status == 'completed', + Document.enabled == True, + Document.archived == False, + Document.updated_at > thirty_days_ago + ).group_by(Document.dataset_id).subquery() + + # Subquery for counting old documents + document_subquery_old = db.session.query( + Document.dataset_id, + func.count(Document.id).label('document_count') + ).filter( + Document.indexing_status == 'completed', + Document.enabled == True, + Document.archived == False, + Document.updated_at < thirty_days_ago + ).group_by(Document.dataset_id).subquery() + + # Main query with join and filter + datasets = (db.session.query(Dataset) + .outerjoin( + document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id + ).outerjoin( + document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id + ).filter( + Dataset.created_at < thirty_days_ago, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0 + ).order_by( + Dataset.created_at.desc() + ).paginate(page=page, per_page=50)) + except NotFound: break + if datasets.items is None or len(datasets.items) == 0: + break page += 1 for dataset in datasets: dataset_query = db.session.query(DatasetQuery).filter( @@ -31,31 +68,23 @@ def clean_unused_datasets_task(): DatasetQuery.dataset_id == dataset.id ).all() if not dataset_query or len(dataset_query) == 0: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset.id, - Document.indexing_status == 'completed', - Document.enabled == True, - 
Document.archived == False, - Document.updated_at > thirty_days_ago - ).all() - if not documents or len(documents) == 0: - try: - # remove index - index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() - index_processor.clean(dataset, None) + try: + # remove index + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) - # update document - update_params = { - Document.enabled: False - } + # update document + update_params = { + Document.enabled: False + } - Document.query.filter_by(dataset_id=dataset.id).update(update_params) - db.session.commit() - click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), - fg='green')) - except Exception as e: - click.echo( - click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.commit() + click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), + fg='green')) + except Exception as e: + click.echo( + click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) end_at = time.perf_counter() click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) diff --git a/api/services/account_service.py b/api/services/account_service.py index 0bcbe8b2c0..d73cec2697 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -47,7 +47,7 @@ class AccountService: ) @staticmethod - def load_user(user_id: str) -> Account: + def load_user(user_id: str) -> None | Account: account = Account.query.filter_by(id=user_id).first() if not account: return None @@ -55,7 +55,7 @@ class AccountService: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: raise Unauthorized("Account is banned or closed.") - current_tenant = 
TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 050295002e..3764166333 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,6 +3,7 @@ import logging import httpx import yaml # type: ignore +from core.app.segments import factory from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_database import db from models.account import Account @@ -150,7 +151,7 @@ class AppDslService: ) @classmethod - def export_dsl(cls, app_model: App) -> str: + def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: """ Export app :param app_model: App instance @@ -171,7 +172,7 @@ class AppDslService: } if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data, app_model) + cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) else: cls._append_model_config_export_data(export_data, app_model) @@ -235,13 +236,16 @@ class AppDslService: ) # init draft workflow + environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, graph=workflow_data.get('graph', {}), features=workflow_data.get('../core/app/features', {}), unique_hash=None, - account=account + account=account, + environment_variables=environment_variables, ) workflow_service.publish_workflow( app_model=app, @@ -276,12 +280,15 @@ class AppDslService: unique_hash = None # sync draft workflow + 
environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=workflow_data.get('graph', {}), features=workflow_data.get('features', {}), unique_hash=unique_hash, - account=account + account=account, + environment_variables=environment_variables, ) return draft_workflow @@ -377,7 +384,7 @@ class AppDslService: return app @classmethod - def _append_workflow_export_data(cls, export_data: dict, app_model: App) -> None: + def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: """ Append workflow export data :param export_data: export data @@ -388,10 +395,7 @@ class AppDslService: if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data['workflow'] = { - "graph": workflow.graph_dict, - "features": workflow.features_dict - } + export_data['workflow'] = workflow.to_dict(include_secret=include_secret) @classmethod def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: diff --git a/api/services/app_service.py b/api/services/app_service.py index 36efde7825..f88e824b07 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -98,7 +98,7 @@ class AppService: model_instance = None if model_instance: - if model_instance.model == default_model_config['model']['name']: + if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: default_model_dict = default_model_config['model'] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) @@ -346,7 +346,7 @@ class AppService: try: provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( ApiToolProvider.id == provider_id - ) + ).first() meta['tool_icons'][tool_name] = 
json.loads(provider.icon) except: meta['tool_icons'][tool_name] = { diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3d9f1851b7..d5a54ba731 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -688,7 +688,7 @@ class DocumentService: dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', @@ -845,13 +845,17 @@ class DocumentService: 'only_main_content': website_info.get('only_main_content', False), 'mode': 'crawl', } + if len(url) > 255: + document_name = url[:200] + '...' + else: + document_name = url document = DocumentService.build_document( dataset, dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], data_source_info, created_from, position, - account, url, batch + account, document_name, batch ) db.session.add(document) db.session.flush() @@ -1059,7 +1063,7 @@ class DocumentService: retrieval_model = document_data['retrieval_model'] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 'reranking_provider_name': '', diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9bcf828712..b83e1d8cb7 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,7 +9,7 @@ from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH, + 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, 'reranking_model': { 
'reranking_provider_name': '', diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index c684c2862b..4f59b86c12 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -133,6 +133,7 @@ class ModelLoadBalancingService: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs): if load_balancing_config.name == '__inherit__': + # FIXME: Mutation to loop iterable `load_balancing_configs` during iteration inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17c..20d21c22a9 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -4,7 +4,6 @@ from os import path from typing import Optional import requests -from flask import current_app from configs import dify_config from constants.languages import languages diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 010d53389a..06b129be69 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -199,7 +199,8 @@ class WorkflowConverter: version='draft', graph=json.dumps(graph), features=json.dumps(features), - created_by=account_id + created_by=account_id, + environment_variables=[], ) db.session.add(workflow) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ea76cfa2e8..cf3f429b02 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,10 +1,12 @@ import json import time +from collections.abc import Sequence from datetime import datetime, timezone from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from 
core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError @@ -62,11 +64,16 @@ class WorkflowService: return workflow - def sync_draft_workflow(self, app_model: App, - graph: dict, - features: dict, - unique_hash: Optional[str], - account: Account) -> Workflow: + def sync_draft_workflow( + self, + *, + app_model: App, + graph: dict, + features: dict, + unique_hash: Optional[str], + account: Account, + environment_variables: Sequence[Variable], + ) -> Workflow: """ Sync draft workflow :raises WorkflowHashNotEqualError @@ -74,10 +81,8 @@ class WorkflowService: # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) - if workflow: - # validate unique hash - if workflow.unique_hash != unique_hash: - raise WorkflowHashNotEqualError() + if workflow and workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() # validate features structure self.validate_features_structure( @@ -94,7 +99,8 @@ class WorkflowService: version='draft', graph=json.dumps(graph), features=json.dumps(features), - created_by=account.id + created_by=account.id, + environment_variables=environment_variables ) db.session.add(workflow) # update draft workflow if found @@ -103,6 +109,7 @@ class WorkflowService: workflow.features = json.dumps(features) workflow.updated_by = account.id workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow.environment_variables = environment_variables # commit db session changes db.session.commit() @@ -138,7 +145,8 @@ class WorkflowService: version=str(datetime.now(timezone.utc).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, - created_by=account.id + created_by=account.id, + 
environment_variables=draft_workflow.environment_variables ) # commit db session changes diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index b27274be37..f129d93de8 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -46,14 +46,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - index_node_ids = [segment.index_node_id for segment in segments] + if segments: + index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() end_at = time.perf_counter() logging.info( click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/__init__.py b/api/tests/integration_tests/model_runtime/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py new file mode 100644 index 0000000000..639227e745 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -0,0 +1,19 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider + + +def test_validate_provider_credentials(): + provider 
= SageMakerProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={} + ) + + provider.validate_provider_credentials( + credentials={} + ) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py new file mode 100644 index 0000000000..c67849dd79 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel + + +def test_validate_credentials(): + model = SageMakerRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + +def test_invoke_model(): + model = SageMakerRerankModel() + + result = model.invoke( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. 
At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py new file mode 100644 index 0000000000..e817e8f04a --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel + + +def test_validate_credentials(): + model = SageMakerEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3', + credentials={ + } + ) + + model.validate_credentials( + model='bge-m3-embedding', + credentials={ + } + ) + + +def test_invoke_model(): + model = SageMakerEmbeddingModel() + + result = model.invoke( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + +def test_get_num_tokens(): + model = SageMakerEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + ] + ) + + assert num_tokens == 0 diff --git a/api/tests/integration_tests/model_runtime/stepfun/__init__.py b/api/tests/integration_tests/model_runtime/stepfun/__init__.py new file 
mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/stepfun/test_llm.py b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py new file mode 100644 index 0000000000..d703147d63 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/stepfun/test_llm.py @@ -0,0 +1,176 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel + + +def test_validate_credentials(): + model = StepfunLargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='step-1-8k', + credentials={ + 'api_key': 'invalid_key' + } + ) + + model.validate_credentials( + model='step-1-8k', + credentials={ + 'api_key': os.environ.get('STEPFUN_API_KEY') + } + ) + +def test_invoke_model(): + model = StepfunLargeLanguageModel() + + response = model.invoke( + model='step-1-8k', + credentials={ + 'api_key': os.environ.get('STEPFUN_API_KEY') + }, + prompt_messages=[ + UserPromptMessage( + content='Hello World!' 
+ ) + ], + model_parameters={ + 'temperature': 0.9, + 'top_p': 0.7 + }, + stop=['Hi'], + stream=False, + user="abc-123" + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = StepfunLargeLanguageModel() + + response = model.invoke( + model='step-1-8k', + credentials={ + 'api_key': os.environ.get('STEPFUN_API_KEY') + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ], + model_parameters={ + 'temperature': 0.9, + 'top_p': 0.7 + }, + stream=True, + user="abc-123" + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_customizable_model_schema(): + model = StepfunLargeLanguageModel() + + schema = model.get_customizable_model_schema( + model='step-1-8k', + credentials={ + 'api_key': os.environ.get('STEPFUN_API_KEY') + } + ) + assert isinstance(schema, AIModelEntity) + + +def test_invoke_chat_model_with_tools(): + model = StepfunLargeLanguageModel() + + result = model.invoke( + model='step-1-8k', + credentials={ + 'api_key': os.environ.get('STEPFUN_API_KEY') + }, + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content="what's the weather today in Shanghai?", + ) + ], + model_parameters={ + 'temperature': 0.9, + 'max_tokens': 100 + }, + tools=[ + PromptMessageTool( + name='get_weather', + description='Determine weather in my location', + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. 
San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + ), + PromptMessageTool( + name='get_stock_price', + description='Get the current stock price', + parameters={ + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock symbol" + } + }, + "required": [ + "symbol" + ] + } + ) + ], + stream=False, + user="abc-123" + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + assert len(result.message.tool_calls) > 0 \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index caa643b2c5..d27f96d8ff 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -55,9 +55,9 @@ def test_execute_code(setup_code_executor_mock): ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}) - pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) - pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) + pool.add(['1', '123', 'args1'], 1) + pool.add(['1', '123', 'args2'], 2) # execute node result = node.run(pool) @@ -109,9 +109,9 @@ def test_execute_code_output_validator(setup_code_executor_mock): ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}) - pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) - pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) + pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) + pool.add(['1', '123', 'args1'], 1) + pool.add(['1', '123', 'args2'], 2) # execute node result = node.run(pool) diff --git 
a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 174dd083b5..dc51528136 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -18,9 +18,9 @@ BASIC_NODE_DATA = { } # construct variable pool -pool = VariablePool(system_variables={}, user_inputs={}) -pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1) -pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2) +pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) +pool.add(['a', 'b123', 'args1'], 1) +pool.add(['a', 'b123', 'args2'], 2) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @@ -43,7 +43,6 @@ def test_get(setup_http_mock): 'headers': 'X-Header:123', 'params': 'A:b', 'body': None, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -52,7 +51,6 @@ def test_get(setup_http_mock): data = result.process_data.get('request', '') assert '?A=b' in data - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data @@ -103,7 +101,6 @@ def test_custom_authorization_header(setup_http_mock): 'headers': 'X-Header:123', 'params': 'A:b', 'body': None, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -113,7 +110,6 @@ def test_custom_authorization_header(setup_http_mock): assert '?A=b' in data assert 'X-Header: 123' in data - assert 'X-Auth: Auth' in data @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @@ -136,7 +132,6 @@ def test_template(setup_http_mock): 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', 'body': None, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -145,7 +140,6 @@ def test_template(setup_http_mock): assert '?A=b' in data assert 'Template=2' in data - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data assert 
'X-Header2: 2' in data @@ -173,7 +167,6 @@ def test_json(setup_http_mock): 'type': 'json', 'data': '{"a": "{{#a.b123.args1#}}"}' }, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -181,7 +174,6 @@ def test_json(setup_http_mock): data = result.process_data.get('request', '') assert '{"a": "1"}' in data - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data @@ -207,7 +199,6 @@ def test_x_www_form_urlencoded(setup_http_mock): 'type': 'x-www-form-urlencoded', 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -215,7 +206,6 @@ def test_x_www_form_urlencoded(setup_http_mock): data = result.process_data.get('request', '') assert 'a=1&b=2' in data - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data @@ -241,7 +231,6 @@ def test_form_data(setup_http_mock): 'type': 'form-data', 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -252,7 +241,6 @@ def test_form_data(setup_http_mock): assert '1' in data assert 'form-data; name="b"' in data assert '2' in data - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data @@ -278,14 +266,12 @@ def test_none_data(setup_http_mock): 'type': 'none', 'data': '123123123' }, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) result = node.run(pool) data = result.process_data.get('request', '') - assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data assert '123123123' not in data @@ -305,7 +291,6 @@ def test_mock_404(setup_http_mock): 'body': None, 'params': '', 'headers': 'X-Header:123', - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) @@ -334,7 +319,6 @@ def test_multi_colons_parse(setup_http_mock): 'type': 'form-data', 'data': 'Referer:http://example5.com\nRedirect:http://example6.com' }, - 'mask_authorization_header': False, } }, **BASIC_NODE_DATA) diff --git 
a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 199cba0aaa..c5b17083c5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -69,8 +69,8 @@ def test_execute_llm(setup_openai_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) - pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + }, user_inputs={}, environment_variables=[]) + pool.add(['abc', 'output'], 'sunny') credentials = { 'openai_api_key': os.environ.get('OPENAI_API_KEY') @@ -184,8 +184,8 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) - pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny') + }, user_inputs={}, environment_variables=[]) + pool.add(['abc', 'output'], 'sunny') credentials = { 'openai_api_key': os.environ.get('OPENAI_API_KEY') diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 128ff4ed6b..fddcacdf54 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -122,7 +122,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -180,7 +180,7 @@ def test_instructions(setup_openai_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={}, environment_variables=[]) 
result = node.run(pool) @@ -246,7 +246,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -310,7 +310,7 @@ def test_completion_parameter_extractor(setup_openai_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -423,7 +423,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={}, environment_variables=[]) result = node.run(pool) diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index f85f95bd3d..dad1f85ebf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -38,9 +38,9 @@ def test_execute_code(setup_code_executor_mock): ) # construct variable pool - pool = VariablePool(system_variables={}, user_inputs={}) - pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) - pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=3) + pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) + pool.add(['1', '123', 'args1'], 1) + pool.add(['1', '123', 'args2'], 3) # execute node result = node.run(pool) diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 941c5b55f4..283d9b8519 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -6,8 +6,8 
@@ from models.workflow import WorkflowNodeExecutionStatus def test_tool_variable_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}) - pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value='1+1') + pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) + pool.add(['1', '123', 'args1'], '1+1') node = ToolNode( tenant_id='1', @@ -45,8 +45,8 @@ def test_tool_variable_invoke(): assert result.outputs['files'] == [] def test_tool_mixed_invoke(): - pool = VariablePool(system_variables={}, user_inputs={}) - pool.append_variable(node_id='1', variable_key_list=['args1'], value='1+1') + pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[]) + pool.add(['1', 'args1'], '1+1') node = ToolNode( tenant_id='1', diff --git a/api/tests/unit_tests/app/test_segment.py b/api/tests/unit_tests/app/test_segment.py new file mode 100644 index 0000000000..7ef37ff646 --- /dev/null +++ b/api/tests/unit_tests/app/test_segment.py @@ -0,0 +1,53 @@ +from core.app.segments import SecretVariable, parser +from core.helper import encrypter +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.entities.variable_pool import VariablePool + + +def test_segment_group_to_text(): + variable_pool = VariablePool( + system_variables={ + SystemVariable('user_id'): 'fake-user-id', + }, + user_inputs={}, + environment_variables=[ + SecretVariable(name='secret_key', value='fake-secret-key'), + ], + ) + variable_pool.add(('node_id', 'custom_query'), 'fake-user-query') + template = ( + 'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.' + ) + segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + + assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.' + assert ( + segments_group.log + == f"Hello, fake-user-id! Your query is fake-user-query. 
And your key is {encrypter.obfuscated_token('fake-secret-key')}." + ) + + +def test_convert_constant_to_segment_group(): + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=[], + ) + template = 'Hello, world!' + segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + assert segments_group.text == 'Hello, world!' + assert segments_group.log == 'Hello, world!' + + +def test_convert_variable_to_segment_group(): + variable_pool = VariablePool( + system_variables={ + SystemVariable('user_id'): 'fake-user-id', + }, + user_inputs={}, + environment_variables=[], + ) + template = '{{#sys.user_id#}}' + segments_group = parser.convert_template(template=template, variable_pool=variable_pool) + assert segments_group.text == 'fake-user-id' + assert segments_group.log == 'fake-user-id' diff --git a/api/tests/unit_tests/app/test_variables.py b/api/tests/unit_tests/app/test_variables.py new file mode 100644 index 0000000000..65db88a4a8 --- /dev/null +++ b/api/tests/unit_tests/app/test_variables.py @@ -0,0 +1,91 @@ +import pytest +from pydantic import ValidationError + +from core.app.segments import ( + FloatVariable, + IntegerVariable, + SecretVariable, + SegmentType, + StringVariable, + factory, +) + + +def test_string_variable(): + test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, StringVariable) + + +def test_integer_variable(): + test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, IntegerVariable) + + +def test_float_variable(): + test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, FloatVariable) + + +def test_secret_variable(): + test_data = {'value_type': 'secret', 
'name': 'test_secret', 'value': 'secret_value'} + result = factory.build_variable_from_mapping(test_data) + assert isinstance(result, SecretVariable) + + +def test_invalid_value_type(): + test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'} + with pytest.raises(ValueError): + factory.build_variable_from_mapping(test_data) + + +def test_frozen_variables(): + var = StringVariable(name='text', value='text') + with pytest.raises(ValidationError): + var.value = 'new value' + + int_var = IntegerVariable(name='integer', value=42) + with pytest.raises(ValidationError): + int_var.value = 100 + + float_var = FloatVariable(name='float', value=3.14) + with pytest.raises(ValidationError): + float_var.value = 2.718 + + secret_var = SecretVariable(name='secret', value='secret_value') + with pytest.raises(ValidationError): + secret_var.value = 'new_secret_value' + + +def test_variable_value_type_immutable(): + with pytest.raises(ValidationError): + StringVariable(value_type=SegmentType.ARRAY, name='text', value='text') + + with pytest.raises(ValidationError): + StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'}) + + var = IntegerVariable(name='integer', value=42) + with pytest.raises(ValidationError): + IntegerVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) + + var = FloatVariable(name='float', value=3.14) + with pytest.raises(ValidationError): + FloatVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) + + var = SecretVariable(name='secret', value='secret_value') + with pytest.raises(ValidationError): + SecretVariable(value_type=SegmentType.ARRAY, name=var.name, value=var.value) + + +def test_build_a_blank_string(): + result = factory.build_variable_from_mapping( + { + 'value_type': 'string', + 'name': 'blank', + 'value': '', + } + ) + assert isinstance(result, StringVariable) + assert result.value == '' diff --git a/api/tests/unit_tests/configs/test_dify_config.py 
b/api/tests/unit_tests/configs/test_dify_config.py index 50bb2b75ac..949a5a1769 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -1,3 +1,4 @@ +import os from textwrap import dedent import pytest @@ -48,7 +49,9 @@ def test_dify_config(example_env_file): # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. def test_flask_configs(example_env_file): flask_app = Flask('app') - flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) + # clear system environment variables + os.environ.clear() + flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 376d3f6521..f9e0989868 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -56,9 +56,9 @@ def test_execute_answer(): pool = VariablePool(system_variables={ SystemVariable.FILES: [], SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) - pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny') - pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.') + }, user_inputs={}, environment_variables=[]) + pool.add(['start', 'weather'], 'sunny') + pool.add(['llm', 'text'], 'You are a helpful AI.') node = AnswerNode( graph_init_params=init_params, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index aeed00f359..c1ebafa968 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -170,6 +170,29 @@ def test_execute_if_else_result_true(): 
} ) + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}, environment_variables=[]) + pool.add(['start', 'array_contains'], ['ab', 'def']) + pool.add(['start', 'array_not_contains'], ['ac', 'def']) + pool.add(['start', 'contains'], 'cabcde') + pool.add(['start', 'not_contains'], 'zacde') + pool.add(['start', 'start_with'], 'abc') + pool.add(['start', 'end_with'], 'zzab') + pool.add(['start', 'is'], 'ab') + pool.add(['start', 'is_not'], 'aab') + pool.add(['start', 'empty'], '') + pool.add(['start', 'not_empty'], 'aaa') + pool.add(['start', 'equals'], 22) + pool.add(['start', 'not_equals'], 23) + pool.add(['start', 'greater_than'], 23) + pool.add(['start', 'less_than'], 21) + pool.add(['start', 'greater_than_or_equal'], 22) + pool.add(['start', 'less_than_or_equal'], 21) + pool.add(['start', 'not_null'], '1212') + # Mock db.session.close() db.session.close = MagicMock() @@ -214,9 +237,9 @@ def test_execute_if_else_result_false(): pool = VariablePool(system_variables={ SystemVariable.FILES: [], SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) - pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def']) - pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def']) + }, user_inputs={}, environment_variables=[]) + pool.add(['start', 'array_contains'], ['1ab', 'def']) + pool.add(['start', 'array_not_contains'], ['ab', 'def']) # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py new file mode 100644 index 0000000000..facea34b5b --- /dev/null +++ b/api/tests/unit_tests/models/test_workflow.py @@ -0,0 +1,95 @@ +from unittest import mock +from uuid import uuid4 + +import contexts +from constants import HIDDEN_VALUE +from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, 
StringVariable +from models.workflow import Workflow + + +def test_environment_variables(): + contexts.tenant_id.set('tenant_id') + + # Create a Workflow instance + workflow = Workflow() + + # Create some EnvironmentVariable instances + variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) + variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) + variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) + variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + + with ( + mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), + mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + ): + # Set the environment_variables property of the Workflow instance + variables = [variable1, variable2, variable3, variable4] + workflow.environment_variables = variables + + # Get the environment_variables property and assert its value + assert workflow.environment_variables == variables + + +def test_update_environment_variables(): + contexts.tenant_id.set('tenant_id') + + # Create a Workflow instance + workflow = Workflow() + + # Create some EnvironmentVariable instances + variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) + variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())}) + variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())}) + variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())}) + + with ( + mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), + mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + ): + variables = [variable1, variable2, variable3, variable4] + + # Set the environment_variables property of the Workflow instance + 
workflow.environment_variables = variables + assert workflow.environment_variables == [variable1, variable2, variable3, variable4] + + # Update the name of variable3 and keep the value as it is + variables[2] = variable3.model_copy( + update={ + 'name': 'new name', + 'value': HIDDEN_VALUE, + } + ) + + workflow.environment_variables = variables + assert workflow.environment_variables[2].name == 'new name' + assert workflow.environment_variables[2].value == variable3.value + + +def test_to_dict(): + contexts.tenant_id.set('tenant_id') + + # Create a Workflow instance + workflow = Workflow() + workflow.graph = '{}' + workflow.features = '{}' + + # Create some EnvironmentVariable instances + + with ( + mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'), + mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'), + ): + # Set the environment_variables property of the Workflow instance + workflow.environment_variables = [ + SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}), + StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}), + ] + + workflow_dict = workflow.to_dict() + assert workflow_dict['environment_variables'][0]['value'] == '' + assert workflow_dict['environment_variables'][1]['value'] == 'text' + + workflow_dict = workflow.to_dict(include_secret=True) + assert workflow_dict['environment_variables'][0]['value'] == 'secret' + assert workflow_dict['environment_variables'][1]['value'] == 'text' diff --git a/docker/.env.example b/docker/.env.example index 4f7e13e823..2f8ec358f4 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -124,6 +124,10 @@ GUNICORN_TIMEOUT=360 # The number of Celery workers. The default is 1, and can be set as needed. 
CELERY_WORKER_AMOUNT= +# API Tool configuration +API_TOOL_DEFAULT_CONNECT_TIMEOUT=10 +API_TOOL_DEFAULT_READ_TIMEOUT=60 + # ------------------------------ # Database Configuration # The database uses PostgreSQL. Please use the public schema. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index cffaa5a6a3..30fdf16b17 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -22,6 +22,8 @@ x-shared-env: &shared-api-worker-env CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} + API_TOOL_DEFAULT_CONNECT_TIMEOUT: ${API_TOOL_DEFAULT_CONNECT_TIMEOUT:-10} + API_TOOL_DEFAULT_READ_TIMEOUT: ${API_TOOL_DEFAULT_READ_TIMEOUT:-60} DB_USERNAME: ${DB_USERNAME:-postgres} DB_PASSWORD: ${DB_PASSWORD:-difyai123456} DB_HOST: ${DB_HOST:-db} diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index 53880c1000..22d4e6189a 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -2,9 +2,9 @@ import requests class DifyClient: - def __init__(self, api_key): + def __init__(self, api_key, base_url: str = 'https://api.dify.ai/v1'): self.api_key = api_key - self.base_url = "https://api.dify.ai/v1" + self.base_url = base_url def _send_request(self, method, endpoint, json=None, params=None, stream=False): headers = { diff --git a/web/.vscode/extensions.json b/web/.vscode/extensions.json new file mode 100644 index 0000000000..d7680d74a5 --- /dev/null +++ b/web/.vscode/extensions.json @@ -0,0 +1,6 @@ +{ + "recommendations": [ + "bradlc.vscode-tailwindcss", + "firsttris.vscode-jest-runner" + ] +} diff --git a/web/README.md b/web/README.md index 2ecba1c8ff..867d822e27 100644 --- a/web/README.md +++ b/web/README.md @@ -74,6 +74,25 @@ npm run start --port=3001 --host=0.0.0.0 If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code 
setting. +## Test + +We start to use [Jest](https://jestjs.io/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. + +You can create a test file with a suffix of `.spec` beside the file that to be tested. For example, if you want to test a file named `util.ts`. The test file name should be `util.spec.ts`. + +Run test: + +```bash +npm run test +``` + +If you are not familiar with writing tests, here is some code to refer to: +* [classnames.spec.ts](./utils/classnames.spec.ts) +* [index.spec.tsx](./app/components/base/button/index.spec.tsx) + + + + ## Documentation Visit to view the full documentation. diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx index 86bee98bcd..09569df8bf 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -40,7 +40,7 @@ const AppDetailLayout: FC = (props) => { const pathname = usePathname() const media = useBreakpoints() const isMobile = media === MediaType.mobile - const { isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() + const { isCurrentWorkspaceEditor } = useAppContext() const { appDetail, setAppDetail, setAppSiderbarExpand } = useStore(useShallow(state => ({ appDetail: state.appDetail, setAppDetail: state.setAppDetail, @@ -53,7 +53,7 @@ const AppDetailLayout: FC = (props) => { selectedIcon: NavIcon }>>([]) - const getNavigations = useCallback((appId: string, isCurrentWorkspaceManager: boolean, isCurrentWorkspaceEditor: boolean, mode: string) => { + const getNavigations = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { const navs = [ ...(isCurrentWorkspaceEditor ? 
[{ @@ -70,7 +70,7 @@ const AppDetailLayout: FC = (props) => { icon: RiTerminalBoxLine, selectedIcon: RiTerminalBoxFill, }, - ...(isCurrentWorkspaceManager + ...(isCurrentWorkspaceEditor ? [{ name: mode !== 'workflow' ? t('common.appMenus.logAndAnn') @@ -115,13 +115,13 @@ const AppDetailLayout: FC = (props) => { } else { setAppDetail(res) - setNavigation(getNavigations(appId, isCurrentWorkspaceManager, isCurrentWorkspaceEditor, res.mode)) + setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) } }).catch((e: any) => { if (e.status === 404) router.replace('/apps') }) - }, [appId, isCurrentWorkspaceManager, isCurrentWorkspaceEditor]) + }, [appId, isCurrentWorkspaceEditor]) useUnmount(() => { setAppDetail() diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 53b31af7f0..34279f816e 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -28,6 +28,9 @@ import EditAppModal from '@/app/components/explore/create-app-modal' import SwitchAppModal from '@/app/components/app/switch-app-modal' import type { Tag } from '@/app/components/base/tag-management/constant' import TagSelector from '@/app/components/base/tag-management/selector' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal' +import { fetchWorkflowDraft } from '@/service/workflow' export type AppCardProps = { app: App @@ -50,6 +53,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const [showDuplicateModal, setShowDuplicateModal] = useState(false) const [showSwitchModal, setShowSwitchModal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) + const [secretEnvList, setSecretEnvList] = useState([]) const onConfirmDelete = useCallback(async () => { try { @@ -123,9 +127,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } } - const onExport = 
async () => { + const onExport = async (include = false) => { try { - const { data } = await exportAppConfig(app.id) + const { data } = await exportAppConfig({ + appID: app.id, + include, + }) const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) a.href = URL.createObjectURL(file) @@ -137,6 +144,25 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } } + const exportCheck = async () => { + if (app.mode !== 'workflow' && app.mode !== 'advanced-chat') { + onExport() + return + } + try { + const workflowDraft = await fetchWorkflowDraft(`/apps/${app.id}/workflows/draft`) + const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') + if (list.length === 0) { + onExport() + return + } + setSecretEnvList(list) + } + catch (e) { + notify({ type: 'error', message: t('app.exportFailed') }) + } + } + const onSwitch = () => { if (onRefresh) onRefresh() @@ -164,7 +190,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.stopPropagation() props.onClick?.() e.preventDefault() - onExport() + exportCheck() } const onClickSwitch = async (e: React.MouseEvent) => { e.stopPropagation() @@ -371,6 +397,13 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { onCancel={() => setShowConfirmDelete(false)} /> )} + {secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + )} ) } diff --git a/web/app/(commonLayout)/tools/page.tsx b/web/app/(commonLayout)/tools/page.tsx index 4e64d8c0df..1b08d54ba3 100644 --- a/web/app/(commonLayout)/tools/page.tsx +++ b/web/app/(commonLayout)/tools/page.tsx @@ -12,15 +12,16 @@ const Layout: FC = () => { const { isCurrentWorkspaceDatasetOperator } = useAppContext() useEffect(() => { - document.title = `${t('tools.title')} - Dify` + if (typeof window !== 'undefined') + document.title = `${t('tools.title')} - Dify` if (isCurrentWorkspaceDatasetOperator) return router.replace('/datasets') - }, []) + }, [isCurrentWorkspaceDatasetOperator, router, t]) 
useEffect(() => { if (isCurrentWorkspaceDatasetOperator) return router.replace('/datasets') - }, [isCurrentWorkspaceDatasetOperator]) + }, [isCurrentWorkspaceDatasetOperator, router]) return } diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index c931afbe7f..ef37ff3c78 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -28,6 +28,9 @@ import type { CreateAppModalProps } from '@/app/components/explore/create-app-mo import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { getRedirection } from '@/utils/app-redirection' import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal' +import { fetchWorkflowDraft } from '@/service/workflow' export type IAppInfoProps = { expand: boolean @@ -47,6 +50,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const [showSwitchTip, setShowSwitchTip] = useState('') const [showSwitchModal, setShowSwitchModal] = useState(false) const [showImportDSLModal, setShowImportDSLModal] = useState(false) + const [secretEnvList, setSecretEnvList] = useState([]) const mutateApps = useContextSelector( AppsContext, @@ -108,11 +112,14 @@ const AppInfo = ({ expand }: IAppInfoProps) => { } } - const onExport = async () => { + const onExport = async (include = false) => { if (!appDetail) return try { - const { data } = await exportAppConfig(appDetail.id) + const { data } = await exportAppConfig({ + appID: appDetail.id, + include, + }) const a = document.createElement('a') const file = new Blob([data], { type: 'application/yaml' }) a.href = URL.createObjectURL(file) @@ -124,6 +131,27 @@ const AppInfo = ({ expand }: IAppInfoProps) => { } } + const exportCheck = async () => { + if (!appDetail) + return + if (appDetail.mode !== 'workflow' && appDetail.mode !== 
'advanced-chat') { + onExport() + return + } + try { + const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) + const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') + if (list.length === 0) { + onExport() + return + } + setSecretEnvList(list) + } + catch (e) { + notify({ type: 'error', message: t('app.exportFailed') }) + } + } + const onConfirmDelete = useCallback(async () => { if (!appDetail) return @@ -314,7 +342,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { )} -
+
{t('app.export')}
{ @@ -403,14 +431,19 @@ const AppInfo = ({ expand }: IAppInfoProps) => { onCancel={() => setShowConfirmDelete(false)} /> )} - { - showImportDSLModal && ( - setShowImportDSLModal(false)} - onBackup={onExport} - /> - ) - } + {showImportDSLModal && ( + setShowImportDSLModal(false)} + onBackup={onExport} + /> + )} + {secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + )}
) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index d7e9856ce4..e971274a71 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -119,11 +119,11 @@ const AppPublisher = ({ diff --git a/web/app/components/app/configuration/config-prompt/index.tsx b/web/app/components/app/configuration/config-prompt/index.tsx index bea4a9e455..7e40fdc84e 100644 --- a/web/app/components/app/configuration/config-prompt/index.tsx +++ b/web/app/components/app/configuration/config-prompt/index.tsx @@ -19,6 +19,9 @@ export type IPromptProps = { promptTemplate: string promptVariables: PromptVariable[] readonly?: boolean + noTitle?: boolean + gradientBorder?: boolean + editorHeight?: number onChange?: (prompt: string, promptVariables: PromptVariable[]) => void } @@ -26,7 +29,10 @@ const Prompt: FC = ({ mode, promptTemplate, promptVariables, + noTitle, + gradientBorder, readonly = false, + editorHeight, onChange, }) => { const { t } = useTranslation() @@ -99,6 +105,9 @@ const Prompt: FC = ({ promptVariables={promptVariables} readonly={readonly} onChange={onChange} + noTitle={noTitle} + gradientBorder={gradientBorder} + editorHeight={editorHeight} /> ) } diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index a15f538227..b0a140fc97 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -28,6 +28,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/config-var' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from 
'@/app/components/base/prompt-editor/plugins/update-block' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' export type ISimplePromptInput = { mode: AppType @@ -35,6 +36,9 @@ export type ISimplePromptInput = { promptVariables: PromptVariable[] readonly?: boolean onChange?: (promp: string, promptVariables: PromptVariable[]) => void + noTitle?: boolean + gradientBorder?: boolean + editorHeight?: number } const Prompt: FC = ({ @@ -43,8 +47,14 @@ const Prompt: FC = ({ promptVariables, readonly = false, onChange, + noTitle, + gradientBorder, + editorHeight: initEditorHeight, }) => { const { t } = useTranslation() + const media = useBreakpoints() + const isMobile = media === MediaType.mobile + const { eventEmitter } = useEventEmitterContextContext() const { modelConfig, @@ -116,6 +126,11 @@ const Prompt: FC = ({ const [showAutomatic, { setTrue: showAutomaticTrue, setFalse: showAutomaticFalse }] = useBoolean(false) const handleAutomaticRes = (res: AutomaticRes) => { + // put eventEmitter in first place to prevent overwrite the configs.prompt_variables.But another problem is that prompt won't hight the prompt_variables. + eventEmitter?.emit({ + type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, + payload: res.prompt, + } as any) const newModelConfig = produce(modelConfig, (draft) => { draft.configs.prompt_template = res.prompt draft.configs.prompt_variables = res.variables.map(key => ({ key, name: key, type: 'string', required: true })) @@ -125,36 +140,35 @@ const Prompt: FC = ({ if (mode !== AppType.completion) setIntroduction(res.opening_statement) showAutomaticFalse() - eventEmitter?.emit({ - type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, - payload: res.prompt, - } as any) } - const minHeight = 228 + const minHeight = initEditorHeight || 228 const [editorHeight, setEditorHeight] = useState(minHeight) return ( -
+
-
-
-
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
- {!readonly && ( - - {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - - )} + {!noTitle && ( +
+
+
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
+ {!readonly && ( + + {t('appDebug.promptTip')} +
} + selector='config-prompt-tooltip'> + + + )} +
+
+ {!isAgent && !readonly && !isMobile && ( + + )} +
-
- {!isAgent && !readonly && ( - - )} -
-
+ )} + = ({ onBlur={() => { handleChange(promptTemplate, getVars(promptTemplate)) }} + editable={!readonly} />
diff --git a/web/app/components/app/configuration/config/automatic/automatic-btn.tsx b/web/app/components/app/configuration/config/automatic/automatic-btn.tsx index 40a9b9d799..f70976082d 100644 --- a/web/app/components/app/configuration/config/automatic/automatic-btn.tsx +++ b/web/app/components/app/configuration/config/automatic/automatic-btn.tsx @@ -2,29 +2,21 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import { Generator } from '@/app/components/base/icons/src/vender/other' export type IAutomaticBtnProps = { onClick: () => void } - -const leftIcon = ( - - - - - - -) const AutomaticBtn: FC = ({ onClick, }) => { const { t } = useTranslation() return ( -
- {leftIcon} + {t('appDebug.operation.automatic')}
) diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index fa58253cac..13cf857edf 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -1,8 +1,20 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' +import { + RiDatabase2Line, + RiFileExcel2Line, + RiGitCommitLine, + RiNewspaperLine, + RiPresentationLine, + RiRoadMapLine, + RiTerminalBoxLine, + RiTranslate, + RiUser2Line, +} from '@remixicon/react' +import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' import Toast from '@/app/components/base/toast' @@ -14,57 +26,97 @@ import OpeningStatement from '@/app/components/app/configuration/features/chat-g import GroupName from '@/app/components/app/configuration/base/group-name' import Loading from '@/app/components/base/loading' import Confirm from '@/app/components/base/confirm' + // type import type { AutomaticRes } from '@/service/debug' -import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' - -const noDataIcon = ( - - - -) +import { Generator } from '@/app/components/base/icons/src/vender/other' export type IGetAutomaticResProps = { mode: AppType isShow: boolean onClose: () => void onFinished: (res: AutomaticRes) => void + isInLLMNode?: boolean } -const genIcon = ( - - - - - -) +const TryLabel: FC<{ + Icon: any + text: string + onClick: () => void +}> = ({ Icon, text, onClick }) => { + return ( +
+ +
{text}
+
+ ) +} const GetAutomaticRes: FC = ({ mode, isShow, onClose, - // appId, + isInLLMNode, onFinished, }) => { const { t } = useTranslation() - const media = useBreakpoints() - const isMobile = media === MediaType.mobile + const tryList = [ + { + icon: RiTerminalBoxLine, + key: 'pythonDebugger', + }, + { + icon: RiTranslate, + key: 'translation', + }, + { + icon: RiPresentationLine, + key: 'meetingTakeaways', + }, + { + icon: RiNewspaperLine, + key: 'writingsPolisher', + }, + { + icon: RiUser2Line, + key: 'professionalAnalyst', + }, + { + icon: RiFileExcel2Line, + key: 'excelFormulaExpert', + }, + { + icon: RiRoadMapLine, + key: 'travelPlanning', + }, + { + icon: RiDatabase2Line, + key: 'SQLSorcerer', + }, + { + icon: RiGitCommitLine, + key: 'GitGud', + }, + ] - const [audiences, setAudiences] = React.useState('') - const [hopingToSolve, setHopingToSolve] = React.useState('') - const isValid = () => { - if (audiences.trim() === '') { - Toast.notify({ - type: 'error', - message: t('appDebug.automatic.audiencesRequired'), - }) - return false + const [instruction, setInstruction] = React.useState('') + const handleChooseTemplate = useCallback((key: string) => { + return () => { + const template = t(`appDebug.generate.template.${key}.instruction`) + setInstruction(template) } - if (hopingToSolve.trim() === '') { + }, [t]) + const isValid = () => { + if (instruction.trim() === '') { Toast.notify({ type: 'error', - message: t('appDebug.automatic.problemRequired'), + message: t('common.errorMsg.fieldRequired', { + field: t('appDebug.generate.instruction'), + }), }) return false } @@ -76,14 +128,17 @@ const GetAutomaticRes: FC = ({ const renderLoading = (
-
{t('appDebug.automatic.loading')}
+
{t('appDebug.generate.loading')}
) const renderNoData = (
- {noDataIcon} -
{t('appDebug.automatic.noData')}
+ +
+
{t('appDebug.generate.noDataLine1')}
+
{t('appDebug.generate.noDataLine2')}
+
) @@ -95,8 +150,7 @@ const GetAutomaticRes: FC = ({ setLoadingTrue() try { const res = await generateRule({ - audiences, - hoping_to_solve: hopingToSolve, + instruction, }) setRes(res) } @@ -107,24 +161,7 @@ const GetAutomaticRes: FC = ({ const [showConfirmOverwrite, setShowConfirmOverwrite] = React.useState(false) - const isShowAutoPromptInput = () => { - if (isMobile) { - // hide prompt panel on mobile if it is loading or has had result - if (isLoading || res) - return false - return true - } - - // always display prompt panel on desktop mode - return true - } - const isShowAutoPromptResPlaceholder = () => { - if (isMobile) { - // hide placeholder panel on mobile - return false - } - return !isLoading && !res } @@ -132,75 +169,96 @@ const GetAutomaticRes: FC = ({ -
- {isShowAutoPromptInput() &&
-
-
{t('appDebug.automatic.title')}
-
{t('appDebug.automatic.description')}
+
+
+
+
{t('appDebug.generate.title')}
+
{t('appDebug.generate.description')}
+
+
+
+
{t('appDebug.generate.tryIt')}
+
+
+
+ {tryList.map(item => ( + + ))} +
{/* inputs */} -
-
-
{t('appDebug.automatic.intendedAudience')}
- setAudiences(e.target.value)} /> -
-
-
{t('appDebug.automatic.solveProblem')}
-