diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index f6092c8633..d681dc6627 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -45,6 +45,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example + - name: Ruff formatter check + if: steps.changed-files.outputs.any_changed == 'true' + run: poetry run -C api ruff format --check ./api + - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." diff --git a/.gitignore b/.gitignore index c52b9d8bbf..22aba65364 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ pyrightconfig.json api/.vscode .idea/ +.vscode \ No newline at end of file diff --git a/api/.env.example b/api/.env.example index 775149f8fd..f81675fd53 100644 --- a/api/.env.example +++ b/api/.env.example @@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= diff --git a/.idea/icon.png b/api/.idea/icon.png similarity index 100% rename from .idea/icon.png rename to api/.idea/icon.png diff --git a/.idea/vcs.xml b/api/.idea/vcs.xml similarity index 100% rename from .idea/vcs.xml rename to api/.idea/vcs.xml diff --git a/.vscode/launch.json b/api/.vscode/launch.json.example similarity index 83% rename from .vscode/launch.json rename to api/.vscode/launch.json.example index e4eb6aef93..e9f8e42dd5 100644 --- a/.vscode/launch.json +++ b/api/.vscode/launch.json.example @@ -5,8 +5,8 @@ "name": "Python: Flask", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": 
"${workspaceFolder}", "envFile": ".env", "module": "flask", "justMyCode": true, @@ -18,15 +18,15 @@ "args": [ "run", "--host=0.0.0.0", - "--port=5001", + "--port=5001" ] }, { "name": "Python: Celery", "type": "debugpy", "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/.venv/bin/python", + "cwd": "${workspaceFolder}", "module": "celery", "justMyCode": true, "envFile": ".env", diff --git a/api/app.py b/api/app.py index 50441cb81d..ad219ca0d6 100644 --- a/api/app.py +++ b/api/app.py @@ -1,6 +1,6 @@ import os -if os.environ.get("DEBUG", "false").lower() != 'true': +if os.environ.get("DEBUG", "false").lower() != "true": from gevent import monkey monkey.patch_all() @@ -57,7 +57,7 @@ warnings.simplefilter("ignore", ResourceWarning) if os.name == "nt": os.system('tzutil /s "UTC"') else: - os.environ['TZ'] = 'UTC' + os.environ["TZ"] = "UTC" time.tzset() @@ -70,13 +70,14 @@ class DifyApp(Flask): # ------------- -config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first +config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first # ---------------------------- # Application Factory Function # ---------------------------- + def create_flask_app_with_configs() -> Flask: """ create a raw flask app @@ -92,7 +93,7 @@ def create_flask_app_with_configs() -> Flask: elif isinstance(value, int | float | bool): os.environ[key] = str(value) elif value is None: - os.environ[key] = '' + os.environ[key] = "" return dify_app @@ -100,10 +101,10 @@ def create_flask_app_with_configs() -> Flask: def create_app() -> Flask: app = create_flask_app_with_configs() - app.secret_key = app.config['SECRET_KEY'] + app.secret_key = app.config["SECRET_KEY"] log_handlers = None - log_file = app.config.get('LOG_FILE') + log_file = app.config.get("LOG_FILE") if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) @@ -111,23 +112,24 @@ def create_app() -> 
Flask: RotatingFileHandler( filename=log_file, maxBytes=1024 * 1024 * 1024, - backupCount=5 + backupCount=5, ), - logging.StreamHandler(sys.stdout) + logging.StreamHandler(sys.stdout), ] logging.basicConfig( - level=app.config.get('LOG_LEVEL'), - format=app.config.get('LOG_FORMAT'), - datefmt=app.config.get('LOG_DATEFORMAT'), + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), handlers=log_handlers, - force=True + force=True, ) - log_tz = app.config.get('LOG_TZ') + log_tz = app.config.get("LOG_TZ") if log_tz: from datetime import datetime import pytz + timezone = pytz.timezone(log_tz) def time_converter(seconds): @@ -162,24 +164,24 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ['console', 'inner_api']: + if request.blueprint not in ["console", "inner_api"]: return None # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get('Authorization', '') + auth_header = request.headers.get("Authorization", "") if not auth_header: - auth_token = request.args.get('_token') + auth_token = request.args.get("_token") if not auth_token: - raise Unauthorized('Invalid Authorization token.') + raise Unauthorized("Invalid Authorization token.") else: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) - user_id = decoded.get('user_id') + user_id = decoded.get("user_id") account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) if account: @@ -190,10 +192,11 @@ def load_user_from_request(request_from_flask_login): @login_manager.unauthorized_handler def unauthorized_handler(): """Handle unauthorized requests.""" - return Response(json.dumps({ - 'code': 'unauthorized', - 'message': "Unauthorized." - }), status=401, content_type="application/json") + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) # register blueprint routers @@ -204,38 +207,36 @@ def register_blueprints(app): from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp - CORS(service_api_bp, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(service_api_bp) - CORS(web_bp, - resources={ - r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(web_bp) - CORS(console_app_bp, - resources={ - r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, - supports_credentials=True, - 
allow_headers=['Content-Type', 'Authorization'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], - expose_headers=['X-Version', 'X-Env'] - ) + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) app.register_blueprint(console_app_bp) - CORS(files_bp, - allow_headers=['Content-Type'], - methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'] - ) + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) @@ -245,29 +246,29 @@ def register_blueprints(app): app = create_app() celery = app.extensions["celery"] -if app.config.get('TESTING'): +if app.config.get("TESTING"): print("App is running in TESTING mode") @app.after_request def after_request(response): """Add Version headers to the response.""" - response.set_cookie('remember_token', '', expires=0) - response.headers.add('X-Version', app.config['CURRENT_VERSION']) - response.headers.add('X-Env', app.config['DEPLOY_ENV']) + response.set_cookie("remember_token", "", expires=0) + response.headers.add("X-Version", app.config["CURRENT_VERSION"]) + response.headers.add("X-Env", app.config["DEPLOY_ENV"]) return response -@app.route('/health') +@app.route("/health") def health(): - return Response(json.dumps({ - 'pid': os.getpid(), - 'status': 'ok', - 'version': app.config['CURRENT_VERSION'] - }), status=200, content_type="application/json") + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + status=200, + content_type="application/json", + ) -@app.route('/threads') +@app.route("/threads") def threads(): num_threads = threading.active_count() threads = threading.enumerate() @@ -278,32 
+279,34 @@ def threads(): thread_id = thread.ident is_alive = thread.is_alive() - thread_list.append({ - 'name': thread_name, - 'id': thread_id, - 'is_alive': is_alive - }) + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) return { - 'pid': os.getpid(), - 'thread_num': num_threads, - 'threads': thread_list + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, } -@app.route('/db-pool-stat') +@app.route("/db-pool-stat") def pool_stat(): engine = db.engine return { - 'pid': os.getpid(), - 'pool_size': engine.pool.size(), - 'checked_in_connections': engine.pool.checkedin(), - 'checked_out_connections': engine.pool.checkedout(), - 'overflow_connections': engine.pool.overflow(), - 'connection_timeout': engine.pool.timeout(), - 'recycle_time': db.engine.pool._recycle + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, } -if __name__ == '__main__': - app.run(host='0.0.0.0', port=5001) +if __name__ == "__main__": + app.run(host="0.0.0.0", port=5001) diff --git a/api/commands.py b/api/commands.py index 82a32f0f5b..41f1a6444c 100644 --- a/api/commands.py +++ b/api/commands.py @@ -27,32 +27,29 @@ from models.provider import Provider, ProviderModel from services.account_service import RegisterService, TenantService -@click.command('reset-password', help='Reset the account password.') -@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') -@click.option('--new-password', prompt=True, help='the new password.') -@click.option('--password-confirm', prompt=True, help='the new password confirm.') +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, 
help="The email address of the account whose password you need to reset") +@click.option("--new-password", prompt=True, help="the new password.") +@click.option("--password-confirm", prompt=True, help="the new password confirm.") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + click.echo(click.style("sorry. The two passwords do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) return # generate password salt @@ -65,80 +62,87 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style('Congratulations! Password has been reset.', fg='green')) + click.echo(click.style("Congratulations! 
Password has been reset.", fg="green")) -@click.command('reset-email', help='Reset the account email.') -@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') -@click.option('--new-email', prompt=True, help='the new email.') -@click.option('--email-confirm', prompt=True, help='the new email confirm.') +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") +@click.option("--new-email", prompt=True, help="the new email.") +@click.option("--email-confirm", prompt=True, help="the new email confirm.") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) return - account = db.session.query(Account). \ - filter(Account.email == email). \ - one_or_none() + account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo( - click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style('Congratulations!, email has been reset.', fg='green')) + click.echo(click.style("Congratulations!, email has been reset.", fg="green")) -@click.command('reset-encrypt-key-pair', help='Reset the asymmetric key pair of workspace for encrypt LLM credentials. 
' - 'After the reset, all LLM credentials will become invalid, ' - 'requiring re-entry.' - 'Only support SELF_HOSTED mode.') -@click.confirmation_option(prompt=click.style('Are you sure you want to reset encrypt key pair?' - ' this operation cannot be rolled back!', fg='red')) +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + ) +) def reset_encrypt_key_pair(): """ Reset the encrypted key pair of workspace for encrypt LLM credentials. After the reset, all LLM credentials will become invalid, requiring re-entry. Only support SELF_HOSTED mode. """ - if dify_config.EDITION != 'SELF_HOSTED': - click.echo(click.style('Sorry, only support SELF_HOSTED mode.', fg='red')) + if dify_config.EDITION != "SELF_HOSTED": + click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style('Sorry, no workspace found. Please enter /install to initialize.', fg='red')) + click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter(Provider.provider_type == 'custom', Provider.tenant_id == tenant.id).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() - click.echo(click.style('Congratulations! 
' - 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) + click.echo( + click.style( + "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + fg="green", + ) + ) -@click.command('vdb-migrate', help='migrate vector db.') -@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.') +@click.command("vdb-migrate", help="migrate vector db.") +@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ['knowledge', 'all']: + if scope in ["knowledge", "all"]: migrate_knowledge_vector_database() - if scope in ['annotation', 'all']: + if scope in ["annotation", "all"]: migrate_annotation_vector_database() @@ -146,7 +150,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style('Start migrate annotation data.', fg='green')) + click.echo(click.style("Start migrate annotation data.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -154,98 +158,103 @@ def migrate_annotation_vector_database(): while True: try: # get apps info - apps = db.session.query(App).filter( - App.status == 'normal' - ).order_by(App.created_at.desc()).paginate(page=page, per_page=50) + apps = ( + db.session.query(App) + .filter(App.status == "normal") + .order_by(App.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for app in apps: total_count = total_count + 1 - click.echo(f'Processing the {total_count} app {app.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." 
+ ) try: - click.echo('Create app annotation index: {}'.format(app.id)) - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app.id - ).first() + click.echo("Create app annotation index: {}".format(app.id)) + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() + ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo('App annotation setting is disabled: {}'.format(app.id)) + click.echo("App annotation setting is disabled: {}".format(app.id)) continue # get dataset_collection_binding info - dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter( - DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id - ).first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) + .first() + ) if not dataset_collection_binding: - click.echo('App annotation collection binding is not exist: {}'.format(app.id)) + click.echo("App annotation collection binding is not exist: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) documents = [] if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app.id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, ) documents.append(document) - 
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Start to migrate annotation, app_id: {app.id}.") try: vector.delete() - click.echo( - click.style(f'Successfully delete vector index for app: {app.id}.', - fg='green')) + click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green")) except Exception as e: - click.echo( - click.style(f'Failed to delete vector index for app {app.id}.', - fg='red')) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} annotations for app {app.id}.', - fg='green')) - vector.create(documents) click.echo( - click.style(f'Successfully created vector index for app {app.id}.', fg='green')) + click.style( + f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) + vector.create(documents) + click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green")) except Exception as e: - click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) raise e - click.echo(f'Successfully migrated app annotation {app.id}.') + click.echo(f"Successfully migrated app annotation {app.id}.") create_count += 1 except Exception as e: click.echo( - click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style( + "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red" + ) + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.', - fg='green')) + click.style( + f"Congratulations! 
Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + fg="green", + ) + ) def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style('Start migrate vector db.', fg='green')) + click.echo(click.style("Start migrate vector db.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -253,87 +262,77 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \ - .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) + datasets = ( + db.session.query(Dataset) + .filter(Dataset.indexing_technique == "high_quality") + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break page += 1 for dataset in datasets: total_count = total_count + 1 - click.echo(f'Processing the {total_count} dataset {dataset.id}. ' - + f'{create_count} created, {skipped_count} skipped.') + click.echo( + f"Processing the {total_count} dataset {dataset.id}. " + + f"{create_count} created, {skipped_count} skipped." 
+ ) try: - click.echo('Create dataset vdb index: {}'.format(dataset.id)) + click.echo("Create dataset vdb index: {}".format(dataset.id)) if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] == vector_type: + if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue - collection_name = '' + collection_name = "" if vector_type == VectorType.WEAVIATE: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.WEAVIATE, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.QDRANT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.MILVUS: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.MILVUS, - "vector_store": {"class_prefix": collection_name} - } + 
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.RELYT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'relyt', - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.TENCENT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.TENCENT, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.PGVECTOR: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.PGVECTOR, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.OPENSEARCH: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ANALYTICDB: @@ -341,16 +340,13 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name} + "vector_store": {"class_prefix": collection_name}, } 
dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.ELASTICSEARCH: dataset_id = dataset.id index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'elasticsearch', - "vector_store": {"class_prefix": index_name} - } + index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -361,29 +357,41 @@ def migrate_knowledge_vector_database(): try: vector.delete() click.echo( - click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.', - fg='green')) + click.style( + f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green" + ) + ) except Exception as e: click.echo( - click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.', - fg='red')) + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) raise e - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == 
"completed", + DocumentSegment.enabled == True, + ) + .all() + ) for segment in segments: document = Document( @@ -393,7 +401,7 @@ def migrate_knowledge_vector_database(): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -401,37 +409,43 @@ def migrate_knowledge_vector_database(): if documents: try: - click.echo(click.style( - f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.', - fg='green')) + click.echo( + click.style( + f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + fg="green", + ) + ) vector.create(documents) click.echo( - click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green')) + click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green") + ) except Exception as e: - click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red')) + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e db.session.add(dataset) db.session.commit() - click.echo(f'Successfully migrated dataset {dataset.id}.') + click.echo(f"Successfully migrated dataset {dataset.id}.") create_count += 1 except Exception as e: db.session.rollback() click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) continue click.echo( - click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.', - fg='green')) + click.style( + f"Congratulations! 
Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green" + ) + ) -@click.command('convert-to-agent-apps', help='Convert Agent Assistant to Agent App.') +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style('Start convert to agent apps.', fg='green')) + click.echo(click.style("Start convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -466,7 +480,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo('Converting app: {}'.format(app.id)) + click.echo("Converting app: {}".format(app.id)) try: app.mode = AppMode.AGENT_CHAT.value @@ -478,137 +492,139 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style('Converted app: {}'.format(app.id), fg='green')) + click.echo(click.style("Converted app: {}".format(app.id), fg="green")) except Exception as e: - click.echo( - click.style('Convert app error: {} {}'.format(e.__class__.__name__, - str(e)), fg='red')) + click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) - click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green')) + click.echo(click.style("Congratulations! 
Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) -@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.') -@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.') +@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.") def add_qdrant_doc_id_index(field: str): - click.echo(click.style('Start add qdrant doc_id index.', fg='green')) + click.echo(click.style("Start add qdrant doc_id index.", fg="green")) vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": - click.echo(click.style('Sorry, only support qdrant vector store.', fg='red')) + click.echo(click.style("Sorry, only support qdrant vector store.", fg="red")) return create_count = 0 try: bindings = db.session.query(DatasetCollectionBinding).all() if not bindings: - click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red')) + click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red")) return import qdrant_client from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + for binding in bindings: if dify_config.QDRANT_URL is None: - raise ValueError('Qdrant url is required.') + raise ValueError("Qdrant url is required.") qdrant_config = QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, root_path=current_app.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) # create payload index - client.create_payload_index(binding.collection_name, field, - 
field_schema=PayloadSchemaType.KEYWORD) + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 except UnexpectedResponse as e: # Collection does not exist, so return if e.status_code == 404: - click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red") + ) continue # Some other error occurred, so re-raise the exception else: - click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red')) + click.echo( + click.style( + f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red" + ) + ) except Exception as e: - click.echo(click.style('Failed to create qdrant client.', fg='red')) + click.echo(click.style("Failed to create qdrant client.", fg="red")) - click.echo( - click.style(f'Congratulations! Create {create_count} collection indexes.', - fg='green')) + click.echo(click.style(f"Congratulations! 
Create {create_count} collection indexes.", fg="green")) -@click.command('create-tenant', help='Create account and tenant.') -@click.option('--email', prompt=True, help='The email address of the tenant account.') -@click.option('--language', prompt=True, help='Account language, default: en-US.') +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") def create_tenant(email: str, language: Optional[str] = None): """ Create tenant account """ if not email: - click.echo(click.style('Sorry, email is required.', fg='red')) + click.echo(click.style("Sorry, email is required.", fg="red")) return # Create account email = email.strip() - if '@' not in email: - click.echo(click.style('Sorry, invalid email address.', fg='red')) + if "@" not in email: + click.echo(click.style("Sorry, invalid email address.", fg="red")) return - account_name = email.split('@')[0] + account_name = email.split("@")[0] if language not in languages: - language = 'en-US' + language = "en-US" # generate random password new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language - ) + account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) TenantService.create_owner_tenant_if_not_exist(account) - click.echo(click.style('Congratulations! Account and tenant created.\n' - 'Account: {}\nPassword: {}'.format(email, new_password), fg='green')) + click.echo( + click.style( + "Congratulations! 
Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + fg="green", + ) + ) -@click.command('upgrade-db', help='upgrade the database') +@click.command("upgrade-db", help="upgrade the database") def upgrade_db(): - click.echo('Preparing database migration...') - lock = redis_client.lock(name='db_upgrade_lock', timeout=60) + click.echo("Preparing database migration...") + lock = redis_client.lock(name="db_upgrade_lock", timeout=60) if lock.acquire(blocking=False): try: - click.echo(click.style('Start database migration.', fg='green')) + click.echo(click.style("Start database migration.", fg="green")) # run db migration import flask_migrate + flask_migrate.upgrade() - click.echo(click.style('Database migration successful!', fg='green')) + click.echo(click.style("Database migration successful!", fg="green")) except Exception as e: - logging.exception(f'Database migration failed, error: {e}') + logging.exception(f"Database migration failed, error: {e}") finally: lock.release() else: - click.echo('Database migration skipped') + click.echo("Database migration skipped") -@click.command('fix-app-site-missing', help='Fix app related site missing issue.') +@click.command("fix-app-site-missing", help="Fix app related site missing issue.") def fix_app_site_missing(): """ Fix app related site missing issue. 
""" - click.echo(click.style('Start fix app related site missing issue.', fg='green')) + click.echo(click.style("Start fix app related site missing issue.", fg="green")) failed_app_ids = [] while True: @@ -639,15 +655,14 @@ where sites.id is null limit 1000""" app_was_created.send(app, account=account) except Exception as e: failed_app_ids.append(app_id) - click.echo(click.style('Fix app {} related site missing issue failed!'.format(app_id), fg='red')) - logging.exception(f'Fix app related site missing issue failed, error: {e}') + click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red")) + logging.exception(f"Fix app related site missing issue failed, error: {e}") continue if not processed_count: break - - click.echo(click.style('Congratulations! Fix app related site missing issue successful!', fg='green')) + click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green")) def register_commands(app): diff --git a/api/configs/app_config.py b/api/configs/app_config.py index b277760edd..c902138595 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -37,6 +37,7 @@ class DifyConfig( CODE_MAX_NUMBER: int = 9223372036854775807 CODE_MIN_NUMBER: int = -9223372036854775808 + CODE_MAX_DEPTH: int = 5 CODE_MAX_STRING_LENGTH: int = 80000 CODE_MAX_STRING_ARRAY_LENGTH: int = 30 CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788..27de2c0461 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings): default=False, ) + class WorkspaceConfig(BaseSettings): """ Workspace configs @@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings): ) +class PositionConfig(BaseSettings): + + POSITION_PROVIDER_PINS: str = Field( + description='The heads of model providers', + default='', + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + 
description='The included model providers', + default='', + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description='The excluded model providers', + default='', + ) + + POSITION_TOOL_PINS: str = Field( + description='The heads of tools', + default='', + ) + + POSITION_TOOL_INCLUDES: str = Field( + description='The included tools', + default='', + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description='The excluded tools', + default='', + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] + + @computed_field + def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''} + + @computed_field + def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''} + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -466,6 +524,7 @@ class FeatureConfig( UpdateConfig, WorkflowConfig, WorkspaceConfig, + PositionConfig, # hosted services config HostedServiceConfig, diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 247fcde655..a7c5eb15a3 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( 
description='Dify version', - default='0.7.0', + default='0.7.1', ) COMMIT_SHA: str = Field( diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e374c04316..e22c3268ef 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1 +1 @@ -HIDDEN_VALUE = '[__HIDDEN__]' +HIDDEN_VALUE = "[__HIDDEN__]" diff --git a/api/constants/languages.py b/api/constants/languages.py index 38e49e0d1e..524dc61b57 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,22 +1,22 @@ language_timezone_mapping = { - 'en-US': 'America/New_York', - 'zh-Hans': 'Asia/Shanghai', - 'zh-Hant': 'Asia/Taipei', - 'pt-BR': 'America/Sao_Paulo', - 'es-ES': 'Europe/Madrid', - 'fr-FR': 'Europe/Paris', - 'de-DE': 'Europe/Berlin', - 'ja-JP': 'Asia/Tokyo', - 'ko-KR': 'Asia/Seoul', - 'ru-RU': 'Europe/Moscow', - 'it-IT': 'Europe/Rome', - 'uk-UA': 'Europe/Kyiv', - 'vi-VN': 'Asia/Ho_Chi_Minh', - 'ro-RO': 'Europe/Bucharest', - 'pl-PL': 'Europe/Warsaw', - 'hi-IN': 'Asia/Kolkata', - 'tr-TR': 'Europe/Istanbul', - 'fa-IR': 'Asia/Tehran', + "en-US": "America/New_York", + "zh-Hans": "Asia/Shanghai", + "zh-Hant": "Asia/Taipei", + "pt-BR": "America/Sao_Paulo", + "es-ES": "Europe/Madrid", + "fr-FR": "Europe/Paris", + "de-DE": "Europe/Berlin", + "ja-JP": "Asia/Tokyo", + "ko-KR": "Asia/Seoul", + "ru-RU": "Europe/Moscow", + "it-IT": "Europe/Rome", + "uk-UA": "Europe/Kyiv", + "vi-VN": "Asia/Ho_Chi_Minh", + "ro-RO": "Europe/Bucharest", + "pl-PL": "Europe/Warsaw", + "hi-IN": "Asia/Kolkata", + "tr-TR": "Europe/Istanbul", + "fa-IR": "Asia/Tehran", } languages = list(language_timezone_mapping.keys()) @@ -26,6 +26,5 @@ def supported_language(lang): if lang in languages: return lang - error = ('{lang} is not a valid language.' 
- .format(lang=lang)) + error = "{lang} is not a valid language.".format(lang=lang) raise ValueError(error) diff --git a/api/constants/model_template.py b/api/constants/model_template.py index cc5a370254..7e1a196356 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -5,82 +5,79 @@ from models.model import AppMode default_app_templates = { # workflow default mode AppMode.WORKFLOW: { - 'app': { - 'mode': AppMode.WORKFLOW.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.WORKFLOW.value, + "enable_site": True, + "enable_api": True, } }, - # completion default mode AppMode.COMPLETION: { - 'app': { - 'mode': AppMode.COMPLETION.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.COMPLETION.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} + "completion_params": {}, }, - 'user_input_form': json.dumps([ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": True, - "default": "" - } - } - ]), - 'pre_prompt': '{{query}}' + "user_input_form": json.dumps( + [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "", + }, + }, + ] + ), + "pre_prompt": "{{query}}", }, - }, - # chat default mode AppMode.CHAT: { - 'app': { - 'mode': AppMode.CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } + "completion_params": {}, + }, + }, }, - # advanced-chat default mode AppMode.ADVANCED_CHAT: { - 'app': { - 'mode': AppMode.ADVANCED_CHAT.value, - 'enable_site': True, - 'enable_api': True - } + "app": { + "mode": AppMode.ADVANCED_CHAT.value, 
+ "enable_site": True, + "enable_api": True, + }, }, - # agent-chat default mode AppMode.AGENT_CHAT: { - 'app': { - 'mode': AppMode.AGENT_CHAT.value, - 'enable_site': True, - 'enable_api': True + "app": { + "mode": AppMode.AGENT_CHAT.value, + "enable_site": True, + "enable_api": True, }, - 'model_config': { - 'model': { + "model_config": { + "model": { "provider": "openai", "name": "gpt-4o", "mode": "chat", - "completion_params": {} - } - } - } + "completion_params": {}, + }, + }, + }, } diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index b6b18f5c5b..623a1a28eb 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -2,6 +2,6 @@ from contextvars import ContextVar from core.workflow.entities.variable_pool import VariablePool -tenant_id: ContextVar[str] = ContextVar('tenant_id') +tenant_id: ContextVar[str] = ContextVar("tenant_id") -workflow_variable_pool: ContextVar[VariablePool] = ContextVar('workflow_variable_pool') +workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2f304b970c..8651597fd7 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -61,6 +61,7 @@ class AppListApi(Resource): parser.add_argument('name', type=str, required=True, location='json') parser.add_argument('description', type=str, location='json') parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') + parser.add_argument('icon_type', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() @@ -94,6 +95,7 @@ class AppImportApi(Resource): parser.add_argument('data', type=str, required=True, nullable=False, location='json') parser.add_argument('name', type=str, location='json') parser.add_argument('description', type=str, location='json') + 
parser.add_argument('icon_type', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() @@ -167,6 +169,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=True, nullable=False, location='json') parser.add_argument('description', type=str, location='json') + parser.add_argument('icon_type', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') parser.add_argument('max_active_requests', type=int, location='json') @@ -208,6 +211,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument('name', type=str, location='json') parser.add_argument('description', type=str, location='json') + parser.add_argument('icon_type', type=str, location='json') parser.add_argument('icon', type=str, location='json') parser.add_argument('icon_background', type=str, location='json') args = parser.parse_args() diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 844788a9e3..995264541f 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -33,7 +33,7 @@ class CompletionConversationApi(Resource): @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_pagination_fields) def get(self, app_model): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument('keyword', type=str, location='args') @@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource): @get_app_model(mode=AppMode.COMPLETION) @marshal_with(conversation_message_detail_fields) def get(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() 
conversation_id = str(conversation_id) @@ -119,7 +119,7 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def delete(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) @@ -154,6 +154,8 @@ class ChatConversationApi(Resource): parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], + required=False, default='-updated_at', location='args') args = parser.parse_args() subquery = ( @@ -225,7 +227,17 @@ class ChatConversationApi(Resource): if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - query = query.order_by(Conversation.created_at.desc()) + match args['sort_by']: + case 'created_at': + query = query.order_by(Conversation.created_at.asc()) + case '-created_at': + query = query.order_by(Conversation.created_at.desc()) + case 'updated_at': + query = query.order_by(Conversation.updated_at.asc()) + case '-updated_at': + query = query.order_by(Conversation.updated_at.desc()) + case _: + query = query.order_by(Conversation.created_at.desc()) conversations = db.paginate( query, @@ -256,7 +268,7 @@ class ChatConversationDetailApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required def delete(self, app_model, conversation_id): - if not current_user.is_admin_or_owner: + if not current_user.is_editor: raise Forbidden() conversation_id = str(conversation_id) diff 
--git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 6aa9f0b475..7db58c048a 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -16,6 +16,7 @@ from models.model import Site def parse_app_site_args(): parser = reqparse.RequestParser() parser.add_argument('title', type=str, required=False, location='json') + parser.add_argument('icon_type', type=str, required=False, location='json') parser.add_argument('icon', type=str, required=False, location='json') parser.add_argument('icon_background', type=str, required=False, location='json') parser.add_argument('description', type=str, required=False, location='json') @@ -53,6 +54,7 @@ class AppSite(Resource): for attr_name in [ 'title', + 'icon_type', 'icon', 'icon_background', 'description', diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6eb97b6c81..a2052b9764 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -459,6 +459,7 @@ class ConvertToWorkflowApi(Resource): if request.data: parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=False, nullable=True, location='json') + parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json') parser.add_argument('icon', type=str, required=False, nullable=True, location='json') parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') args = parser.parse_args() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a5bc2dd86a..f2a9a965ae 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -573,13 +573,13 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR 
| VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value ] } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR: return { 'retrieval_method': [ RetrievalMethod.SEMANTIC_SEARCH.value, diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 44bda8e771..4598a14611 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -25,6 +25,8 @@ class ConversationApi(Resource): parser = reqparse.RequestParser() parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], + required=False, default='-updated_at', location='args') args = parser.parse_args() try: @@ -33,7 +35,8 @@ class ConversationApi(Resource): user=end_user, last_id=args['last_id'], limit=args['limit'], - invoke_from=InvokeFrom.SERVICE_API + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args['sort_by'] ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0849eb72ba..0fa2aa65b2 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -53,19 +53,22 @@ 
class SegmentApi(DatasetApiResource): raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() parser.add_argument('segments', type=list, required=False, nullable=True, location='json') args = parser.parse_args() - for args_item in args['segments']: - SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + if args['segments'] is not None: + for args_item in args['segments']: + SegmentService.segment_create_args_validate(args_item, document) + segments = SegmentService.multi_create_segment(args['segments'], document, dataset) + return { + 'data': marshal(segments, segment_fields), + 'doc_form': document.doc_form + }, 200 + else: + return {"error": "Segemtns is required"}, 400 def get(self, tenant_id, dataset_id, document_id): """Create single segment.""" diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index b83ea3a525..334ee382a2 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -26,6 +26,8 @@ class ConversationListApi(WebApiResource): parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], + required=False, default='-updated_at', location='args') args = parser.parse_args() pinned = None @@ -40,6 +42,7 @@ class 
ConversationListApi(WebApiResource): limit=args['limit'], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, + sort_by=args['sort_by'] ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 99ec86e935..0f4a7cabe5 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -6,6 +6,7 @@ from configs import dify_config from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from libs.helper import AppIconUrlField from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -28,8 +29,10 @@ class AppSiteApi(WebApiResource): 'title': fields.String, 'chat_color_theme': fields.String, 'chat_color_theme_inverted': fields.Boolean, + 'icon_type': fields.String, 'icon': fields.String, 'icon_background': fields.String, + 'icon_url': AppIconUrlField, 'description': fields.String, 'copyright': fields.String, 'privacy_policy': fields.String, diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0ff9389793..d8290ca608 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -64,15 +64,19 @@ class BaseAgentRunner(AppRunner): """ Agent runner :param tenant_id: tenant id + :param application_generate_entity: application generate entity + :param conversation: conversation :param app_config: app generate entity :param model_config: model config :param config: dataset config :param queue_manager: queue manager :param message: message :param user_id: user id - :param agent_llm_callback: agent llm callback - :param callback: callback :param memory: memory + :param prompt_messages: prompt messages + :param variables_pool: variables pool + :param db_variables: db variables + :param model_instance: model instance """ self.tenant_id = tenant_id 
self.application_generate_entity = application_generate_entity @@ -445,7 +449,7 @@ class BaseAgentRunner(AppRunner): try: tool_responses = json.loads(agent_thought.observation) except Exception as e: - tool_responses = { tool: agent_thought.observation for tool in tools } + tool_responses = dict.fromkeys(tools, agent_thought.observation) for tool in tools: # generate a uuid for tool call diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 06492bb12f..89c948d2e2 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -292,6 +292,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): handle invoke action :param action: action :param tool_instances: tool instances + :param message_file_ids: message file ids + :param trace_manager: trace manager :return: observation, meta """ # action is tool call, invoke tool diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 3eb006b46e..15fa4d99fd 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,6 +1,6 @@ import re -from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory @@ -13,7 +13,7 @@ class BasicVariablesConfigManager: :param config: model config args """ external_data_variables = [] - variables = [] + variable_entities = [] # old external_data_tools external_data_tools = config.get('external_data_tools', []) @@ -30,50 +30,41 @@ class BasicVariablesConfigManager: ) # variables and external_data_tools - for variable in config.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - if 'config' not in val: + for 
variables in config.get('user_input_form', []): + variable_type = list(variables.keys())[0] + if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: + variable = variables[variable_type] + if 'config' not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] + variable=variable['variable'], + type=variable['type'], + config=variable['config'] ) ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, + elif variable_type in [ + VariableEntityType.TEXT_INPUT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.SELECT, ]: - variables.append( + variable = variables[variable_type] + variable_entities.append( VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - variables.append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), + type=variable_type, + variable=variable.get('variable'), + description=variable.get('description'), + label=variable.get('label'), + required=variable.get('required', False), + max_length=variable.get('max_length'), + options=variable.get('options'), + default=variable.get('default'), ) ) - return variables, external_data_variables + return variable_entities, external_data_variables @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> 
tuple[dict, list[str]]: @@ -183,4 +174,4 @@ class BasicVariablesConfigManager: config=config ) - return config, ["external_data_tools"] \ No newline at end of file + return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 05a42a898e..bbb10d3d76 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel): advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntityType(str, Enum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external-data-tool" + + class VariableEntity(BaseModel): """ Variable Entity. """ - class Type(Enum): - TEXT_INPUT = 'text-input' - SELECT = 'select' - PARAGRAPH = 'paragraph' - NUMBER = 'number' - - @classmethod - def value_of(cls, value: str) -> 'VariableEntity.Type': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid variable type value {value}') variable: str label: str description: Optional[str] = None - type: Type + type: VariableEntityType required: bool = False max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None hint: Optional[str] = None - @property - def name(self) -> str: - return self.variable - class ExternalDataVariableEntity(BaseModel): """ @@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. 
""" - workflow_id: str \ No newline at end of file + workflow_id: str diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 351eb05d8a..5a1e5973cd 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,7 +29,7 @@ from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import App, Conversation, EndUser, Message @@ -46,7 +46,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ): """ Generate App response. @@ -73,8 +73,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # parse files files = args['files'] if args.get('files') else [] @@ -133,8 +134,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id: str, user: Account, args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): """ Generate App response. 
@@ -157,8 +157,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get('conversation_id') + if conversation_id: + conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user) # convert to app config app_config = AdvancedChatAppConfigManager.get_app_config( @@ -200,8 +201,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, conversation: Conversation | None = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + stream: bool = True): is_first_conversation = False if not conversation: is_first_conversation = True @@ -270,11 +270,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create a variable pool. system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation_id, - SystemVariable.USER_ID: user_id, - SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: conversation_id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, } variable_pool = VariablePool( system_variables=system_inputs, @@ -362,7 +362,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except 
Exception as e: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f8efcb5960..2b3596ded2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -49,7 +49,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk from events.message_event import message_was_created @@ -74,7 +74,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _workflow: Workflow _user: Union[Account, EndUser] # Deprecated - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__( @@ -108,10 +108,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._message = message # Deprecated self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_id, } self._task_state = AdvancedChatTaskState( @@ -249,8 +249,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ for message in self._queue_manager.listen(): if (message.event - and 
hasattr(message.event, 'metadata') - and message.event.metadata + and getattr(message.event, 'metadata', None) and message.event.metadata.get('is_answer_previous_node', False) and publisher): publisher.publish(message=message) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6f48aa2363..9e331dff4d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity +from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType class BaseAppGenerator: @@ -9,29 +9,29 @@ class BaseAppGenerator: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} + filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} return filtered_inputs def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.name) + user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.name} is required in input form') + raise ValueError(f'{var.variable} is required in input form') if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? 
return var.default or '' if ( var.type in ( - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.SELECT, - VariableEntity.Type.PARAGRAPH, + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, ) and user_input_value and not isinstance(user_input_value, str) ): - raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") - if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") + if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: if '.' in user_input_value: @@ -39,14 +39,14 @@ class BaseAppGenerator: else: return int(user_input_value) except ValueError: - raise ValueError(f"{var.name} in input form must be a valid number") - if var.type == VariableEntity.Type.SELECT: + raise ValueError(f"{var.variable} in input form must be a valid number") + if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.name} in input form must be one of the following: {options}') - elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): + raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') return user_input_value diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 58c7d04b83..2c5feaaaaf 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ 
-1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None) -> int: """ Get pre calculate rest tokens @@ -126,7 +128,7 @@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ @@ -254,6 +256,7 @@ class AppRunner: :param invoke_result: invoke result :param queue_manager: application queue manager :param stream: 
stream + :param agent: agent :return: """ if not stream: @@ -276,6 +279,7 @@ class AppRunner: Handle invoke result direct :param invoke_result: invoke result :param queue_manager: application queue manager + :param agent: agent :return: """ queue_manager.publish( @@ -291,6 +295,7 @@ class AppRunner: Handle invoke result :param invoke_result: invoke result :param queue_manager: application queue manager + :param agent: agent :return: """ model = None @@ -366,7 +371,7 @@ class AppRunner: message_id=message_id, trace_manager=app_generate_entity.trace_manager ) - + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: @@ -418,7 +423,7 @@ class AppRunner: inputs=inputs, query=query ) - + def query_app_annotations_to_reply(self, app_record: App, message: Message, query: str, diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 12f69f1528..fceed95b91 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from datetime import datetime, timezone from typing import Optional, Union from sqlalchemy import and_ @@ -36,17 +37,17 @@ logger = logging.getLogger(__name__) class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response( - self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool = False, + self, application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: 
Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool = False, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -138,6 +139,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): """ Initialize generate records :param application_generate_entity: application generate entity + :conversation conversation :return: """ app_config = application_generate_entity.app_config @@ -192,6 +194,9 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.add(conversation) db.session.commit() db.session.refresh(conversation) + else: + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db.session.commit() message = Message( app_id=app_config.app_id, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 994919391e..e388d0184b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db @@ -67,8 +67,8 @@ class WorkflowAppRunner: # Create a variable pool. 
system_inputs = { - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id, + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, } variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 5022eb0438..de8542d7b9 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -43,7 +43,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.end.end_node import EndNode from extensions.ext_database import db from models.account import Account @@ -67,7 +67,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, @@ -92,8 +92,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._workflow = workflow self._workflow_system_variables = { - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.USER_ID: user_id + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_id } self._task_state = WorkflowTaskState( diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9a861c29e2..6a1ab23041 
100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None \ No newline at end of file + single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index 321bc0ad02..5c713cac67 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -99,7 +99,13 @@ class ObjectSegment(Segment): class ArraySegment(Segment): @property def markdown(self) -> str: - return '\n'.join(['- ' + item.markdown for item in self.value]) + items = [] + for item in self.value: + if hasattr(item, 'to_markdown'): + items.append(item.to_markdown()) + else: + items.append(str(item)) + return '\n'.join(items) class ArrayAnySegment(ArraySegment): diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 8baa8ba09e..bd98c82720 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -2,7 +2,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from models.account import Account from models.model import EndUser from models.workflow import Workflow @@ -13,4 +13,4 @@ class WorkflowCycleStateManager: _workflow: Workflow _user: Union[Account, EndUser] _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] + _workflow_system_variables: dict[SystemVariableKey, Any] diff --git 
a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 01b89907db..085ff07cfd 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -99,7 +99,7 @@ class MessageFileParser: # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): """ transform message files @@ -144,7 +144,7 @@ class MessageFileParser: return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): """ transform file to file obj diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791..8cf184ac44 100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,6 +3,7 @@ from collections import OrderedDict from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file @@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. 
+ :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. + Overall logic: exclude > include > pin + :param position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Chcek if the object should be filtered out. 
+ Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + def sort_by_position_map( position_map: dict[str, int], data: list[Any], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b20c6ed187..8173028ed7 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -700,6 +700,7 @@ class IndexingRunner: DatasetDocument.tokens: tokens, DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, + DatasetDocument.error: None, } ) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 8e99ad3dec..1ceed8043c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -271,9 +271,8 @@ class ModelInstance: :param content_text: text content to be translated :param tenant_id: user tenant id - :param user: unique user id :param voice: model timbre - :param streaming: output is streaming + :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -369,6 +368,15 @@ class ModelManager: return ModelInstance(provider_model_bundle, model) + def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Return first provider and the first model in the provider + :param tenant_id: tenant id + :param model_type: model type + 
:return: provider name, model name + """ + return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) + def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: """ Get default model instance @@ -401,6 +409,10 @@ class LBModelManager: managed_credentials: Optional[dict] = None) -> None: """ Load balancing model manager + :param tenant_id: tenant_id + :param provider: provider + :param model_type: model_type + :param model: model name :param load_balancing_configs: all load balancing configurations :param managed_credentials: credentials if load balancing configuration name is __inherit__ """ @@ -499,7 +511,6 @@ class LBModelManager: config.id ) - res = redis_client.exists(cooldown_cache_key) res = cast(bool, res) return res diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 87fe4f681c..d2076bf74a 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,4 +1,3 @@ - from core.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { @@ -94,5 +93,16 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { }, 'required': False, 'options': ['JSON', 'XML'], - } -} \ No newline at end of file + }, + DefaultParameterName.JSON_SCHEMA: { + 'label': { + 'en_US': 'JSON Schema', + }, + 'type': 'text', + 'help': { + 'en_US': 'Set a response json schema will ensure LLM to adhere it.', + 'zh_Hans': '设置返回的json schema,llm将按照它返回', + }, + 'required': False, + }, +} diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3d471787bb..c257ce63d2 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -95,6 +95,7 @@ class DefaultParameterName(Enum): FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = 
"max_tokens" RESPONSE_FORMAT = "response_format" + JSON_SCHEMA = "json_schema" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': @@ -118,6 +119,7 @@ class ParameterType(Enum): INT = "int" STRING = "string" BOOLEAN = "boolean" + TEXT = "text" class ModelPropertyKey(Enum): diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf89..716bb63566 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -151,9 +151,9 @@ class AIModel(ABC): os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + and not model_schema_yaml.startswith('_') + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith('.yaml') ] # get _position.yaml file path diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 02ba0c9410..cfc8942c79 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -185,7 +185,7 @@ if you are not sure about the structure. stream=stream, user=user ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) @@ -249,10 +249,10 @@ if you are not sure about the structure. 
prompt_messages=prompt_messages, input_generator=new_generator() ) - + return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], + def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] ) -> Generator[LLMResultChunk, None, None]: """ @@ -310,7 +310,7 @@ if you are not sure about the structure. ) ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, + def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]) \ -> Generator[LLMResultChunk, None, None]: """ @@ -470,7 +470,7 @@ if you are not sure about the structure. :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -792,6 +792,13 @@ if you are not sure about the structure. 
if not isinstance(parameter_value, str): raise ValueError(f"Model Parameter {parameter_name} should be string.") + # validate options + if parameter_rule.options and parameter_value not in parameter_rule.options: + raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") + elif parameter_rule.type == ParameterType.TEXT: + if not isinstance(parameter_value, str): + raise ValueError(f"Model Parameter {parameter_name} should be text.") + # validate options if parameter_rule.options and parameter_value not in parameter_rule.options: raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 3d2bac1c31..f9ddd86f68 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -70,7 +70,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - # max font is 4096,there is 3500 limit for each request + # max length is 4096 characters, there is 3500 limit for each request max_length = 3500 if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb..e2d17e3257 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from 
core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ class ModelProviderFactory: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 17cf65dc3a..c233596637 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -84,7 +84,8 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): def _add_custom_parameters(self, credentials: dict) -> None: credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": + credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml index 34c802c2a7..41e9c2e808 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml @@ -31,6 +31,14 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API 
Base + type: text-input + required: false + placeholder: + zh_Hans: Base URL, 如:https://api.moonshot.cn/v1 + en_US: Base URL, e.g. https://api.moonshot.cn/v1 model_credential_schema: model: label: diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index 21661b9a2b..ac7313aaa1 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -2,6 +2,7 @@ - gpt-4o - gpt-4o-2024-05-13 - gpt-4o-2024-08-06 +- chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 - gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml new file mode 100644 index 0000000000..98e236650c --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -0,0 +1,44 @@ +model: chatgpt-4o-latest +label: + zh_Hans: chatgpt-4o-latest + en_US: chatgpt-4o-latest +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml 
b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index cf2de0f73a..7e430c51a7 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '2.50' output: '10.00' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index b97fbf8aab..23dcf85085 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index aae2729bdf..06135c9584 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,3 +1,4 @@ +import json import logging from collections.abc import Generator from typing import Optional, Union, cast @@ -544,13 +545,18 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not correct json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] 
= {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format - + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} @@ -922,11 +928,14 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" if model.startswith('ft:'): model = model.split(':')[1] + # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. + if model == "chatgpt-4o-latest": + model = "gpt-4o" + try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -946,7 +955,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): raise NotImplementedError( f"get_num_tokens_from_messages() is not presently implemented " f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "See https://platform.openai.com/docs/advanced-usage/managing-tokens for " "information on how messages are converted to tokens." 
) num_tokens = 0 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml new file mode 100644 index 0000000000..cf2de0f73a --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-2024-08-06 +label: + zh_Hans: gpt-4o-2024-08-06 + en_US: gpt-4o-2024-08-06 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml new file mode 100644 index 0000000000..87712874b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml @@ -0,0 +1,61 @@ +model: Llama3-Chinese_v2 +label: + en_US: Llama3-Chinese_v2 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 
用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. 
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml new file mode 100644 index 0000000000..f16f3de60b --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3-70B-Instruct-GPTQ-Int4 +label: + en_US: Meta-Llama-3-70B-Instruct-GPTQ-Int4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 1024 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. 
Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. 
The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml new file mode 100644 index 0000000000..21267c240b --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3-8B-Instruct +label: + en_US: Meta-Llama-3-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
+ - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml new file mode 100644 index 0000000000..80c7ec40f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +label: + en_US: Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 410960 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. 
+ - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml new file mode 100644 index 0000000000..bbab46344c --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-8B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Meta-Llama-3.1-8B-Instruct +label: + en_US: Meta-Llama-3.1-8B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.1 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. 
+ - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml index af6fb91cd9..ec6d9bcc14 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen-14B-Chat-Int4.yaml @@ -55,7 +55,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml index 4ab9a80055..b561a53039 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-110B-Chat-GPTQ-Int4.yaml @@ -55,7 +55,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml index 4a8b1cf479..841dd97f35 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: chat - context_size: 8192 + context_size: 2048 parameter_rules: - name: temperature use_template: temperature @@ -55,7 +55,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml index b076504493..33d5d12b22 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml @@ -6,7 +6,7 @@ features: - agent-thought model_properties: mode: completion - context_size: 8192 + context_size: 32768 parameter_rules: - name: temperature use_template: temperature @@ -55,7 +55,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml index e24a69fe63..62255cc7d2 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml @@ -8,12 +8,12 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 8192 + context_size: 2048 parameter_rules: - name: temperature use_template: temperature type: float - default: 0.3 + default: 0.7 min: 0.0 max: 2.0 help: @@ -57,7 +57,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml new file mode 100644 index 0000000000..cea6560295 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Qwen2-72B-Instruct +label: + en_US: Qwen2-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. 
+ - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml index e3d804729d..2f3f1f0225 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml @@ -8,7 +8,7 @@ features: - stream-tool-call model_properties: mode: completion - context_size: 8192 + context_size: 32768 parameter_rules: - name: temperature use_template: temperature @@ -57,7 +57,7 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. pricing: - input: '0.000' - output: '0.000' - unit: '0.000' + input: "0.000" + output: "0.000" + unit: "0.000" currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml index b95f6bdc1b..2c9eac0e49 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml @@ -1,6 +1,15 @@ +- Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +- Meta-Llama-3.1-8B-Instruct +- Meta-Llama-3-70B-Instruct-GPTQ-Int4 +- Meta-Llama-3-8B-Instruct - Qwen2-72B-Instruct-GPTQ-Int4 +- Qwen2-72B-Instruct - Qwen2-7B -- Qwen1.5-110B-Chat-GPTQ-Int4 +- Qwen-14B-Chat-Int4 - Qwen1.5-72B-Chat-GPTQ-Int4 - Qwen1.5-7B -- Qwen-14B-Chat-Int4 +- Qwen1.5-110B-Chat-GPTQ-Int4 +- deepseek-v2-chat +- deepseek-v2-lite-chat +- Llama3-Chinese_v2 +- chatglm3-6b diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml new file mode 100644 index 0000000000..f9c26b7f90 --- 
/dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml @@ -0,0 +1,61 @@ +model: chatglm3-6b +label: + en_US: chatglm3-6b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. 
The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml new file mode 100644 index 0000000000..078922ef95 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml @@ -0,0 +1,61 @@ +model: deepseek-v2-chat +label: + en_US: deepseek-v2-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. 
Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. 
The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml new file mode 100644 index 0000000000..4ff3af7b51 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml @@ -0,0 +1,61 @@ +model: deepseek-v2-lite-chat +label: + en_US: deepseek-v2-lite-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 2048 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
+ - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml new file mode 100644 index 0000000000..5756fb3d14 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-en-v1.5.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-large-en-v1.5 +model_type: text-embedding +model_properties: + context_size: 32768 diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml new file mode 100644 index 0000000000..4204ab2860 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/BAAI-bge-large-zh-v1.5.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 32768 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py b/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 0000000000..ff3635bfeb --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,4 @@ +model: netease-youdao/bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 512 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 0000000000..807f531b08 --- /dev/null +++ 
b/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: BAAI/bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 8192 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py new file mode 100644 index 0000000000..6835915816 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -0,0 +1,87 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class SiliconflowRerankModel(RerankModel): + + def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], + score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) -> RerankResult: + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') + if base_url.endswith('/'): + base_url = base_url[:-1] + try: + response = httpx.post( + base_url + '/rerank', + json={ + "model": model, + "query": query, + "documents": docs, + "top_n": top_n, + "return_documents": True + }, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results['results']: + rerank_document = RerankDocument( + index=result['index'], + text=result['document']['text'], + score=result['relevance_score'], + ) + if score_threshold is None or result['relevance_score'] >= 
score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + try: + + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError] + } \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index 1ebb1e6d8b..c46a891604 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -12,10 +12,11 @@ help: en_US: Get your API Key from SiliconFlow zh_Hans: 从 SiliconFlow 获取 API Key url: - en_US: https://cloud.siliconflow.cn/keys + en_US: https://cloud.siliconflow.cn/account/ak supported_model_types: - llm - text-embedding + - rerank - speech2text configurate_methods: - predefined-model diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py 
b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index a75db78d8c..4e1bb0a5a4 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -159,6 +159,8 @@ You should also complete the text started with ``` but not tell ``` directly. """ if model in ['qwen-turbo-chat', 'qwen-plus-chat']: model = model.replace('-chat', '') + if model == 'farui-plus': + model = 'qwen-farui-plus' if model in self.tokenizers: tokenizer = self.tokenizers[model] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 8bea30324b..add5822bef 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -35,7 +35,10 @@ from core.model_runtime.model_providers.volcengine_maas.errors import ( RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs +from core.model_runtime.model_providers.volcengine_maas.llm.models import ( + get_model_config, + get_v2_req_params, +) from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException logger = logging.getLogger(__name__) @@ -95,37 +98,12 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): -> LLMResult | Generator: client = MaaSClient.from_credential(credentials) - - req_params = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).copy() - if credentials.get('context_size'): - req_params['max_prompt_tokens'] = credentials.get('context_size') - if credentials.get('max_tokens'): - req_params['max_new_tokens'] = credentials.get('max_tokens') - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = 
model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') - if stop: - req_params['stop'] = stop - + req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} - if tools: extra_model_kwargs['tools'] = [ MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools ] - resp = MaaSClient.wrap_exception( lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) if not stream: @@ -197,10 +175,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): """ used to define customizable model schema """ - max_tokens = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens') - if credentials.get('max_tokens'): - max_tokens = int(credentials.get('max_tokens')) + model_config = get_model_config(credentials) + rules = [ ParameterRule( name='temperature', @@ -234,10 +210,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): name='presence_penalty', type=ParameterType.FLOAT, use_template='presence_penalty', - label={ - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', - }, + label=I18nObject( + en_US='Presence Penalty', + zh_Hans= '存在惩罚', + ), min=-2.0, max=2.0, ), @@ -245,10 +221,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): name='frequency_penalty', type=ParameterType.FLOAT, use_template='frequency_penalty', - label={ - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', - }, + label=I18nObject( + en_US= 'Frequency Penalty', + zh_Hans= '频率惩罚', + ), min=-2.0, max=2.0, ), @@ -257,7 +233,7 @@ class 
VolcengineMaaSLargeLanguageModel(LargeLanguageModel): type=ParameterType.INT, use_template='max_tokens', min=1, - max=max_tokens, + max=model_config.properties.max_tokens, default=512, label=I18nObject( zh_Hans='最大生成长度', @@ -266,17 +242,10 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): ), ] - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('mode'): - model_properties[ModelPropertyKey.MODE] = credentials.get('mode') - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - - model_features = ModelConfigs.get( - credentials['base_model_name'], {}).get('features', []) - + model_properties = {} + model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size + model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value + entity = AIModelEntity( model=model, label=I18nObject( @@ -286,7 +255,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): model_type=ModelType.LLM, model_properties=model_properties, parameter_rules=rules, - features=model_features, + features=model_config.features, ) return entity diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index 3e5938f3b4..c5e53d8955 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,181 +1,123 @@ +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.model_entities import ModelFeature -ModelConfigs = { - 'Doubao-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], 
- }, - 'Doubao-lite-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Skylark2-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4000, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-8B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-70B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-8k': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 
16384, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 65536, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B-Fin': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Mistral-7B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 2048, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - } + +class ModelProperties(BaseModel): + context_size: int + max_tokens: int + mode: LLMMode + +class ModelConfig(BaseModel): + properties: ModelProperties + features: list[ModelFeature] + + +configs: dict[str, ModelConfig] = { + 'Doubao-pro-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-pro-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-pro-128k': ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Doubao-lite-128k': ModelConfig( + properties=ModelProperties(context_size=131072, 
max_tokens=131072, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL] + ), + 'Skylark2-pro-4k': ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT), + features=[] + ), + 'Llama3-8B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), + features=[] + ), + 'Llama3-70B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), + features=[] + ), + 'Moonshot-v1-8k': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[] + ), + 'Moonshot-v1-32k': ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), + features=[] + ), + 'Moonshot-v1-128k': ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), + features=[] + ), + 'GLM3-130B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[] + ), + 'GLM3-130B-Fin': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[] + ), + 'Mistral-7B': ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), + features=[] + ) } + +def get_model_config(credentials: dict)->ModelConfig: + base_model = credentials.get('base_model_name', '') + model_configs = configs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get('context_size', 0)), + max_tokens=int(credentials.get('max_tokens', 0)), + mode= LLMMode.value_of(credentials.get('mode', 'chat')), + ), + features=[] + ) + return model_configs + + +def get_v2_req_params(credentials: dict, model_parameters: dict, + stop: list[str] | None=None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params['max_prompt_tokens'] = 
model_configs.properties.context_size + req_params['max_new_tokens'] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get('max_tokens'): + req_params['max_new_tokens'] = model_parameters.get('max_tokens') + if model_parameters.get('temperature'): + req_params['temperature'] = model_parameters.get('temperature') + if model_parameters.get('top_p'): + req_params['top_p'] = model_parameters.get('top_p') + if model_parameters.get('top_k'): + req_params['top_k'] = model_parameters.get('top_k') + if model_parameters.get('presence_penalty'): + req_params['presence_penalty'] = model_parameters.get( + 'presence_penalty') + if model_parameters.get('frequency_penalty'): + req_params['frequency_penalty'] = model_parameters.get( + 'frequency_penalty') + + if stop: + req_params['stop'] = stop + + return req_params \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 569f89e975..2d8f972b94 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -1,9 +1,27 @@ +from pydantic import BaseModel + + +class ModelProperties(BaseModel): + context_size: int + max_chunks: int + +class ModelConfig(BaseModel): + properties: ModelProperties + ModelConfigs = { - 'Doubao-embedding': { - 'req_params': {}, - 'model_properties': { - 'context_size': 4096, - 'max_chunks': 1, - } - }, + 'Doubao-embedding': ModelConfig( + properties=ModelProperties(context_size=4096, max_chunks=1) + ), } + +def get_model_config(credentials: dict)->ModelConfig: + base_model = credentials.get('base_model_name', '') + model_configs = ModelConfigs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get('context_size', 0)), + 
max_chunks=int(credentials.get('max_chunks', 0)), + ) + ) + return model_configs \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 10b01c0d0d..8ac632369e 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -30,7 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import ( RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs +from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException @@ -115,14 +115,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): """ generate custom model entities from credentials """ - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - if credentials.get('max_chunks'): - model_properties[ModelPropertyKey.MAX_CHUNKS] = int( - credentials.get('max_chunks', 4096)) + model_config = get_model_config(credentials) + model_properties = {} + model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size + model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks entity = AIModelEntity( model=model, label=I18nObject(en_US=model), diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py new file mode 100644 index 0000000000..0230c78b75 --- /dev/null +++ 
b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -0,0 +1,198 @@ +from datetime import datetime, timedelta +from threading import Lock + +from requests import post + +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + +baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens_lock = Lock() + + +class BaiduAccessToken: + api_key: str + access_token: str + expires: datetime + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.access_token = '' + self.expires = datetime.now() + timedelta(days=3) + + @staticmethod + def _get_access_token(api_key: str, secret_key: str) -> str: + """ + request access token from Baidu + """ + try: + response = post( + url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json' + }, + ) + except Exception as e: + raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') + + resp = response.json() + if 'error' in resp: + if resp['error'] == 'invalid_client': + raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') + elif resp['error'] == 'unknown_error': + raise InternalServerError(f'Internal server error: {resp["error_description"]}') + elif resp['error'] == 'invalid_request': + raise BadRequestError(f'Bad request: {resp["error_description"]}') + elif resp['error'] == 'rate_limit_exceeded': + raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') + else: + raise Exception(f'Unknown error: {resp["error_description"]}') + + return resp['access_token'] + + @staticmethod + def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': + """ + LLM from Baidu requires access token to invoke the API. 
+ however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) + + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. + """ + + # loop up cache, remove expired access token + baidu_access_tokens_lock.acquire() + now = datetime.now() + for key in list(baidu_access_tokens.keys()): + token = baidu_access_tokens[key] + if token.expires < now: + baidu_access_tokens.pop(key) + + if api_key not in baidu_access_tokens: + # if access token not in cache, request it + token = BaiduAccessToken(api_key) + baidu_access_tokens[api_key] = token + # release it to enhance performance + # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock + baidu_access_tokens_lock.release() + # try to get access token + token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + token.access_token = token_str + token.expires = now + timedelta(days=3) + return token + else: + # if access token in cache, return it + token = baidu_access_tokens[api_key] + baidu_access_tokens_lock.release() + return token + + +class _CommonWenxin: + api_bases = { + 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', + 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', + 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', + 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', + 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', + 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', + 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', + 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', + 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', + 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', + 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', + 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', + 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', + 'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en', + 'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh', + 'tao-8k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k', + } + + function_calling_supports = [ + 'ernie-bot', + 'ernie-bot-8k', + 'ernie-3.5-8k', + 'ernie-3.5-8k-0205', + 'ernie-3.5-8k-1222', + 'ernie-3.5-4k-0205', + 'ernie-3.5-128k', + 'ernie-4.0-8k', + 'ernie-4.0-turbo-8k', + 'ernie-4.0-turbo-8k-preview', + 'yi_34b_chat' + ] + + api_key: str = '' + secret_key: str = '' + + def __init__(self, api_key: str, secret_key: str): + self.api_key = api_key + self.secret_key = secret_key + + @staticmethod + def _to_credential_kwargs(credentials: dict) -> dict: + credentials_kwargs = { + "api_key": credentials['api_key'], + "secret_key": credentials['secret_key'] + } + return credentials_kwargs + + def _handle_error(self, code: int, msg: str): + error_map = { + 1: InternalServerError, + 2: InternalServerError, + 3: BadRequestError, + 4: RateLimitReachedError, + 6: InvalidAuthenticationError, + 13: InvalidAPIKeyError, + 14: InvalidAPIKeyError, + 15: InvalidAPIKeyError, + 17: RateLimitReachedError, + 18: RateLimitReachedError, + 19: RateLimitReachedError, + 100: InvalidAPIKeyError, + 111: InvalidAPIKeyError, + 200: InternalServerError, + 336000: InternalServerError, + 336001: BadRequestError, + 336002: BadRequestError, + 336003: BadRequestError, + 336004: InvalidAuthenticationError, + 336005: InvalidAPIKeyError, + 336006: BadRequestError, + 336007: BadRequestError, + 336008: BadRequestError, + 336100: InternalServerError, + 336101: BadRequestError, + 336102: BadRequestError, + 336103: BadRequestError, + 336104: BadRequestError, + 336105: BadRequestError, + 336200: InternalServerError, + 336303: BadRequestError, + 337006: BadRequestError + } + + if code in error_map: + raise error_map[code](msg) + else: + raise InternalServerError(f'Unknown error: {msg}') + + def _get_access_token(self) -> str: + token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) + return token.access_token diff --git 
a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index e345663d36..8109949b1d 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,102 +1,17 @@ from collections.abc import Generator -from datetime import datetime, timedelta from enum import Enum from json import dumps, loads -from threading import Lock from typing import Any, Union from requests import Response, post from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, ) -# map api_key to access_token -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} -baidu_access_tokens_lock = Lock() - -class BaiduAccessToken: - api_key: str - access_token: str - expires: datetime - - def __init__(self, api_key: str) -> None: - self.api_key = api_key - self.access_token = '' - self.expires = datetime.now() + timedelta(days=3) - - def _get_access_token(api_key: str, secret_key: str) -> str: - """ - request access token from Baidu - """ - try: - response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, - ) - except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') - - resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': - raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 
'unknown_error': - raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': - raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': - raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') - else: - raise Exception(f'Unknown error: {resp["error_description"]}') - - return resp['access_token'] - - @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': - """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) - - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. - """ - - # loop up cache, remove expired access token - baidu_access_tokens_lock.acquire() - now = datetime.now() - for key in list(baidu_access_tokens.keys()): - token = baidu_access_tokens[key] - if token.expires < now: - baidu_access_tokens.pop(key) - - if api_key not in baidu_access_tokens: - # if access token not in cache, request it - token = BaiduAccessToken(api_key) - baidu_access_tokens[api_key] = token - # release it to enhance performance - # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock - baidu_access_tokens_lock.release() - # try to get access token - token_str = BaiduAccessToken._get_access_token(api_key, secret_key) - token.access_token = token_str - token.expires = now + timedelta(days=3) - return token - else: - # if access token in cache, return it - token = baidu_access_tokens[api_key] - baidu_access_tokens_lock.release() - return token - class ErnieMessage: class Role(Enum): @@ -120,51 +35,7 @@ class ErnieMessage: self.content = content self.role = role -class ErnieBotModel: - api_bases = { 
- 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - } - - function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' - ] - - api_key: str = '' - secret_key: str = '' - - def __init__(self, api_key: str, secret_key: str): - self.api_key = api_key - self.secret_key = secret_key +class ErnieBotModel(_CommonWenxin): def generate(self, model: str, stream: bool, messages: list[ErnieMessage], parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ @@ -199,51 +70,6 @@ class ErnieBotModel: return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - def _handle_error(self, code: int, msg: str): - error_map = { - 1: InternalServerError, - 2: InternalServerError, - 3: BadRequestError, - 4: RateLimitReachedError, - 6: InvalidAuthenticationError, - 13: InvalidAPIKeyError, - 14: InvalidAPIKeyError, - 15: InvalidAPIKeyError, - 17: RateLimitReachedError, - 18: RateLimitReachedError, - 19: RateLimitReachedError, - 100: InvalidAPIKeyError, - 111: InvalidAPIKeyError, - 200: InternalServerError, - 336000: InternalServerError, - 336001: BadRequestError, - 336002: BadRequestError, - 336003: BadRequestError, - 336004: InvalidAuthenticationError, - 336005: InvalidAPIKeyError, - 336006: BadRequestError, - 336007: BadRequestError, - 336008: BadRequestError, - 336100: InternalServerError, - 336101: BadRequestError, - 336102: BadRequestError, - 336103: BadRequestError, - 336104: BadRequestError, - 336105: BadRequestError, - 336200: InternalServerError, - 336303: 
BadRequestError, - 337006: BadRequestError - } - - if code in error_map: - raise error_map[code](msg) - else: - raise InternalServerError(f'Unknown error: {msg}') - - def _get_access_token(self) -> str: - token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) - return token.access_token - def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py deleted file mode 100644 index 67d76b4a29..0000000000 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ /dev/null @@ -1,17 +0,0 @@ -class InvalidAuthenticationError(Exception): - pass - -class InvalidAPIKeyError(Exception): - pass - -class RateLimitReachedError(Exception): - pass - -class InsufficientAccountBalance(Exception): - pass - -class InternalServerError(Exception): - pass - -class BadRequestError(Exception): - pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index d39d63deee..140606298c 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from 
core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( - BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, -) +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken +from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage +from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure @@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): api_key = credentials['api_key'] secret_key = credentials['secret_key'] try: - BaiduAccessToken._get_access_token(api_key, secret_key) + BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') @@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): :return: Invoke error mapping """ - return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], - InvokeAuthorizationError: [ - InvalidAuthenticationError, - InsufficientAccountBalance, - InvalidAPIKeyError, - ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] - } + return invoke_error_mapping() diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml new file mode 100644 index 
0000000000..74fadb7f9d --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml @@ -0,0 +1,9 @@ +model: bge-large-en +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml new file mode 100644 index 0000000000..d4af27ec38 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml @@ -0,0 +1,9 @@ +model: bge-large-zh +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml new file mode 100644 index 0000000000..eda48d9655 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml @@ -0,0 +1,9 @@ +model: embedding-v1 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml new file mode 100644 index 0000000000..e28f253eb6 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml @@ -0,0 +1,9 @@ +model: tao-8k +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 1 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py new file mode 100644 index 
0000000000..10ac1a1861 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -0,0 +1,184 @@ +import time +from abc import abstractmethod +from collections.abc import Mapping +from json import dumps +from typing import Any, Optional + +import numpy as np +from requests import Response, post + +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + invoke_error_mapping, +) + + +class TextEmbedding: + @abstractmethod + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + raise NotImplementedError + + +class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + access_token = self._get_access_token() + url = f'{self.api_bases[model]}?access_token={access_token}' + body = self._build_embed_request_body(model, texts, user) + headers = { + 'Content-Type': 'application/json', + } + + resp = post(url, data=dumps(body), headers=headers) + if resp.status_code != 200: + raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + return self._handle_embed_response(model, resp) + + def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: + if len(texts) == 0: + raise BadRequestError('The number of texts should not be zero.') + body = { + 'input': texts, + 'user_id': user, + } + return body 
+ + def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + data = response.json() + if 'error_code' in data: + code = data['error_code'] + msg = data['error_msg'] + # raise error + self._handle_error(code, msg) + + embeddings = [v['embedding'] for v in data['data']] + _usage = data['usage'] + tokens = _usage['prompt_tokens'] + total_tokens = _usage['total_tokens'] + + return embeddings, tokens, total_tokens + + +class WenxinTextEmbeddingModel(TextEmbeddingModel): + def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: + return WenxinTextEmbedding(api_key, secret_key) + + def _invoke(self, model: str, credentials: dict, texts: list[str], + user: Optional[str] = None) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) + user = user if user else 'ErnieBotDefault' + + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + used_total_tokens = 0 + + for i, text in enumerate(texts): + + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + for i in _iter: + embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( + model, + inputs[i: i + max_chunks], + user) + 
used_tokens += _used_tokens + used_total_tokens += _total_used_tokens + batched_embeddings += embeddings_batch + + usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens) + return TextEmbeddingResult( + model=model, + embeddings=batched_embeddings, + usage=usage, + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + total_num_tokens = 0 + for text in texts: + total_num_tokens += self._get_num_tokens_by_gpt2(text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + api_key = credentials['api_key'] + secret_key = credentials['secret_key'] + try: + BaiduAccessToken.get_access_token(api_key, secret_key) + except Exception as e: + raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return invoke_error_mapping() + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=total_tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml 
b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index b3a1f60824..6a6b38e6a1 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -17,6 +17,7 @@ help: en_US: https://cloud.baidu.com/wenxin.html supported_model_types: - llm + - text-embedding configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py new file mode 100644 index 0000000000..0fbd0f55ec --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -0,0 +1,57 @@ +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ], + InvokeServerUnavailableError: [ + InternalServerError + ], + InvokeRateLimitError: [ + RateLimitReachedError + ], + InvokeAuthorizationError: [ + InvalidAuthenticationError, + InsufficientAccountBalance, + InvalidAPIKeyError, + ], + InvokeBadRequestError: [ + BadRequestError, + KeyError + ] + } + + +class InvalidAuthenticationError(Exception): + pass + +class InvalidAPIKeyError(Exception): + pass + +class RateLimitReachedError(Exception): + pass + +class InsufficientAccountBalance(Exception): + pass + +class InternalServerError(Exception): + pass + +class BadRequestError(Exception): + pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 988bb0ce44..4760e8f118 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): tools=tools, stop=stop, stream=stream, user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) ) @@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'completion_type' not in credentials: if 'chat' in extra_param.model_ability: @@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): else: extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], 
+ api_key=credentials.get('api_key') ) if 'chat' in extra_args.model_ability: @@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): xinference_client = Client( base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_model = xinference_client.get_model(credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 4e7543fd99..d809537479 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel): # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 9ee3621317..62b77f22e5 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 11f1e29cb3..3a8d704c25 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): server_url = credentials['server_url'] model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + api_key = credentials.get('api_key') + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, + model_uid=model_uid, + api_key=api_key, + ) if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) + client = Client( + base_url=server_url, + api_key=api_key, + ) try: handle = client.get_model(model_uid=model_uid) diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index a564a021b1..bfa752df8c 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel): extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) if 'text-to-audio' not in extra_param.model_ability: @@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel): credentials['server_url'] = credentials['server_url'][:-1] try: - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) + api_key = credentials.get('api_key') + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + handle = RESTfulAudioModelHandle( + credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + ) model_support_voice = [x.get("value") for x in 
self.get_tts_model_voices(model=model, credentials=credentials)] diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 7db483a485..75161ad376 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -35,13 +35,13 @@ cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) } return cache[model_uid]['value'] @@ -56,7 +56,7 @@ class XinferenceHelper: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ @@ -70,9 +70,10 @@ class XinferenceHelper: session = Session() session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3)) + headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') if response.status_code != 200: diff --git a/api/core/prompt/simple_prompt_transform.py 
b/api/core/prompt/simple_prompt_transform.py index 452b270348..fd7ed0181b 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,11 +1,10 @@ import enum import json import os -from typing import Optional +from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class ModelMode(enum.Enum): COMPLETION = 'completion' @@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform): prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ @@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform): return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: 
str, files: list[FileVar]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] for file in files: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c0b3746e18..6c68cee7be 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -5,6 +5,7 @@ from typing import Optional from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -18,12 +19,9 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.helper.position_helper import is_filtered from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -45,6 +43,7 @@ class ProviderManager: """ ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. 
""" + def __init__(self) -> None: self.decoding_rsa_key = None self.decoding_cipher_rsa = None @@ -117,6 +116,16 @@ class ProviderManager: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: + + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue + provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) @@ -271,6 +280,24 @@ class ProviderManager: ) ) + def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Get names of first model and its provider + + :param tenant_id: workspace id + :param model_type: model type + :return: provider name, model name + """ + provider_configurations = self.get_configurations(tenant_id) + + # get available models from provider_configurations + all_models = provider_configurations.get_models( + model_type=model_type, + only_active=False + ) + + return all_models[0].provider.provider, all_models[0].model + def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ -> TenantDefaultModel: """ @@ -811,7 +838,7 @@ class ProviderManager: -> list[ModelSettings]: """ Convert to model settings. 
- + :param provider_entity: provider entity :param provider_model_settings: provider model settings include enabled, load balancing enabled :param load_balancing_model_configs: load balancing model configs :return: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 33ca5bc028..c9f2f35af0 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -152,8 +152,27 @@ class PGVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # do not support bm25 search - return [] + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cur: + cur.execute( + f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), to_tsquery(%s)) AS score + FROM {self.table_name} + WHERE to_tsvector(text) @@ plainto_tsquery(%s) + ORDER BY score DESC + LIMIT {top_k}""", + # f"'{query}'" is required in order to account for whitespace in query + (f"'{query}'", f"'{query}'"), + ) + + docs = [] + + for record in cur: + metadata, text, score = record + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + + return docs def delete(self) -> None: with self._get_cursor() as cur: diff --git a/api/core/tools/docs/zh_Hans/advanced_scale_out.md b/api/core/tools/docs/zh_Hans/advanced_scale_out.md index 93f81b033d..0385dfe4e7 100644 --- a/api/core/tools/docs/zh_Hans/advanced_scale_out.md +++ b/api/core/tools/docs/zh_Hans/advanced_scale_out.md @@ -21,6 +21,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create an image message :param image: the url of the image + :param save_as: save as :return: the image message """ ``` @@ -34,6 +35,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a link message :param link: the url of the link + :param save_as: save as :return: the link message """ ``` @@ -47,6 +49,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a text message 
:param text: the text of the message + :param save_as: save as :return: the text message """ ``` @@ -63,6 +66,8 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型 create a blob message :param blob: the blob + :param meta: meta + :param save_as: save as :return: the blob message """ ``` diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4..062668fc5b 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/crossref/_assets/icon.svg b/api/core/tools/provider/builtin/crossref/_assets/icon.svg new file mode 100644 index 0000000000..aa629de7cb --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/_assets/icon.svg @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py new file mode 100644 index 0000000000..404e483e0d --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -0,0 +1,20 @@ +from core.tools.errors import 
ToolProviderCredentialValidationError +from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class CrossRefProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + CrossRefQueryDOITool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "doi": '10.1007/s00894-022-05373-8', + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/crossref/crossref.yaml b/api/core/tools/provider/builtin/crossref/crossref.yaml new file mode 100644 index 0000000000..da67fbec3a --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.yaml @@ -0,0 +1,29 @@ +identity: + author: Sakura4036 + name: crossref + label: + en_US: CrossRef + zh_Hans: CrossRef + description: + en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers. 
+ zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。 + icon: icon.svg + tags: + - search +credentials_for_provider: + mailto: + type: text-input + required: true + label: + en_US: email address + zh_Hans: email地址 + pt_BR: email address + placeholder: + en_US: Please input your email address + zh_Hans: 请输入你的email地址 + pt_BR: Please input your email address + help: + en_US: According to the requirements of Crossref, an email address is required + zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址 + pt_BR: According to the requirements of Crossref, an email address is required + url: https://api.crossref.org/swagger-ui/index.html diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py new file mode 100644 index 0000000000..a43c0989e4 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -0,0 +1,25 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class CrossRefQueryDOITool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its DOI. 
+ """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get('doi') + if not doi: + raise ToolParameterValidationError('doi is required.') + # doc: https://github.com/CrossRef/rest-api-doc + url = f"https://api.crossref.org/works/{doi}" + response = requests.get(url) + response.raise_for_status() + response = response.json() + message = response.get('message', {}) + + return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml new file mode 100644 index 0000000000..9c16da25ed --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml @@ -0,0 +1,23 @@ +identity: + name: crossref_query_doi + author: Sakura4036 + label: + en_US: CrossRef Query DOI + zh_Hans: CrossRef DOI 查询 + pt_BR: CrossRef Query DOI +description: + human: + en_US: A tool for searching literature information using CrossRef by DOI. + zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。 + pt_BR: A tool for searching literature information using CrossRef by DOI. + llm: A tool for searching literature information using CrossRef by DOI. +parameters: + - name: doi + type: string + required: true + label: + en_US: DOI + zh_Hans: DOI + pt_BR: DOI + llm_description: DOI for searching in CrossRef + form: llm diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py new file mode 100644 index 0000000000..946aa6dc94 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -0,0 +1,120 @@ +import time +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +def convert_time_str_to_seconds(time_str: str) -> int: + """ + Convert a time string to seconds. 
+ example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 + """ + time_str = time_str.lower().strip().replace(' ', '') + seconds = 0 + if 'h' in time_str: + hours, time_str = time_str.split('h') + seconds += int(hours) * 3600 + if 'm' in time_str: + minutes, time_str = time_str.split('m') + seconds += int(minutes) * 60 + if 's' in time_str: + seconds += int(time_str.replace('s', '')) + return seconds + + +class CrossRefQueryTitleAPI: + """ + Tool for querying the metadata of a publication using its title. + Crossref API doc: https://github.com/CrossRef/rest-api-doc + """ + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" + rate_limit: int = 50 + rate_interval: float = 1 + max_limit: int = 1000 + + def __init__(self, mailto: str): + self.mailto = mailto + + def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + """ + Query the metadata of a publication using its title. 
+ :param query: the title of the publication + :param rows: the number of results to return + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto) + response = requests.get(url) + response.raise_for_status() + rate_limit = int(response.headers['x-ratelimit-limit']) + # convert time string to seconds + rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval']) + + self.rate_limit = rate_limit + self.rate_interval = rate_interval + + response = response.json() + if response['status'] != 'ok': + return [] + + message = response['message'] + if fuzzy_query: + # fuzzy query return all items + return message['items'] + else: + for paper in message['items']: + title = paper['title'][0] + if title.lower() != query.lower(): + continue + return [paper] + return [] + + def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]: + """ + Query the metadata of a publication using its title. 
+ :param query: the title of the publication + :param rows: the number of results to return + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + rows = min(rows, self.max_limit) + if rows > self.rate_limit: + # query multiple times + query_times = rows // self.rate_limit + 1 + results = [] + + for i in range(query_times): + result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query) + if fuzzy_query: + results.extend(result) + else: + # fuzzy_query=False, only one result + if result: + return result + time.sleep(self.rate_interval) + return results + else: + # query once + return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query) + + +class CrossRefQueryTitleTool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its title. + """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get('query') + fuzzy_query = tool_parameters.get('fuzzy_query', False) + rows = tool_parameters.get('rows', 3) + sort = tool_parameters.get('sort', 'relevance') + order = tool_parameters.get('order', 'desc') + mailto = self.runtime.credentials['mailto'] + + result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) + + return [self.create_json_message(r) for r in result] diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.yaml b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml new file mode 100644 index 0000000000..5579c77f52 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml @@ -0,0 +1,105 @@ +identity: + name: crossref_query_title + author: Sakura4036 + label: + en_US: CrossRef Title Query + zh_Hans: CrossRef 标题查询 + pt_BR: CrossRef Title Query +description: + human: + en_US: A tool for querying 
literature information using CrossRef by title. + zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。 + pt_BR: A tool for querying literature information using CrossRef by title. + llm: A tool for querying literature information using CrossRef by title. +parameters: + - name: query + type: string + required: true + label: + en_US: 标题 + zh_Hans: 查询语句 + pt_BR: 标题 + human_description: + en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份 + pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + llm_description: key words for querying in Web of Science + form: llm + - name: fuzzy_query + type: boolean + default: false + label: + en_US: Whether to fuzzy search + zh_Hans: 是否模糊搜索 + pt_BR: Whether to fuzzy search + human_description: + en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无 + pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + form: form + - name: limit + type: number + required: false + label: + en_US: max query number + zh_Hans: 最大搜索数 + pt_BR: max query number + human_description: + en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches) + zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数) + pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches) + form: llm + default: 50 + - name: sort + type: select + required: true + options: + - value: relevance + label: + en_US: relevance + zh_Hans: 相关性 + pt_BR: relevance + - value: published + label: + en_US: publication date + zh_Hans: 出版日期 + pt_BR: publication date + - value: references-count + label: + en_US: references-count + zh_Hans: 引用次数 + pt_BR: references-count + 
default: relevance + label: + en_US: sorting field + zh_Hans: 排序字段 + pt_BR: sorting field + human_description: + en_US: Sorting of query results + zh_Hans: 检索结果的排序字段 + pt_BR: Sorting of query results + form: form + - name: order + type: select + required: true + options: + - value: desc + label: + en_US: descending + zh_Hans: 降序 + pt_BR: descending + - value: asc + label: + en_US: ascending + zh_Hans: 升序 + pt_BR: ascending + default: desc + label: + en_US: Order + zh_Hans: 排序 + pt_BR: Order + human_description: + en_US: Order of query results + zh_Hans: 检索结果的排序方式 + pt_BR: Order of query results + form: form diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index fca34ae15f..0c13ec662a 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -29,6 +29,6 @@ class GitlabProvider(BuiltinToolProviderController): if response.status_code != 200: raise ToolProviderCredentialValidationError((response.json()).get('message')) except Exception as e: - raise ToolProviderCredentialValidationError("Gitlab Access Tokens and Api Version is invalid. {}".format(e)) + raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.yaml b/api/core/tools/provider/builtin/gitlab/gitlab.yaml index b5feea2382..22d7ebf73a 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.yaml +++ b/api/core/tools/provider/builtin/gitlab/gitlab.yaml @@ -2,37 +2,37 @@ identity: author: Leo.Wang name: gitlab label: - en_US: Gitlab - zh_Hans: Gitlab + en_US: GitLab + zh_Hans: GitLab description: - en_US: Gitlab plugin for commit - zh_Hans: 用于获取Gitlab commit的插件 + en_US: GitLab plugin, API v4 only. 
+ zh_Hans: 用于获取GitLab内容的插件,目前仅支持 API v4。 icon: gitlab.svg credentials_for_provider: access_tokens: type: secret-input required: true label: - en_US: Gitlab access token - zh_Hans: Gitlab access token + en_US: GitLab access token + zh_Hans: GitLab access token placeholder: - en_US: Please input your Gitlab access token - zh_Hans: 请输入你的 Gitlab access token + en_US: Please input your GitLab access token + zh_Hans: 请输入你的 GitLab access token help: - en_US: Get your Gitlab access token from Gitlab - zh_Hans: 从 Gitlab 获取您的 access token + en_US: Get your GitLab access token from GitLab + zh_Hans: 从 GitLab 获取您的 access token url: https://docs.gitlab.com/16.9/ee/api/oauth2.html site_url: type: text-input required: false default: 'https://gitlab.com' label: - en_US: Gitlab site url - zh_Hans: Gitlab site url + en_US: GitLab site url + zh_Hans: GitLab site url placeholder: - en_US: Please input your Gitlab site url - zh_Hans: 请输入你的 Gitlab site url + en_US: Please input your GitLab site url + zh_Hans: 请输入你的 GitLab site url help: - en_US: Find your Gitlab url - zh_Hans: 找到你的Gitlab url + en_US: Find your GitLab url + zh_Hans: 找到你的 GitLab url url: https://gitlab.com/help diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 212bdb03ab..880d722bda 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -18,6 +18,7 @@ class GitlabCommitsTool(BuiltinTool): employee = tool_parameters.get('employee', '') start_time = tool_parameters.get('start_time', '') end_time = tool_parameters.get('end_time', '') + change_type = tool_parameters.get('change_type', 'all') if not project: return self.create_text_message('Project is required') @@ -36,11 +37,11 @@ class GitlabCommitsTool(BuiltinTool): site_url = 'https://gitlab.com' # Get commit content - result = self.fetch(user_id, site_url, access_token, project, 
employee, start_time, end_time) + result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) - return self.create_text_message(json.dumps(result, ensure_ascii=False)) + return [self.create_json_message(item) for item in result] - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '') -> list[dict[str, Any]]: + def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] @@ -74,7 +75,7 @@ class GitlabCommitsTool(BuiltinTool): for commit in commits: commit_sha = commit['id'] - print(f"\tCommit SHA: {commit_sha}") + author_name = commit['author_name'] diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" diff_response = requests.get(diff_url, headers=headers) @@ -87,14 +88,23 @@ class GitlabCommitsTool(BuiltinTool): removed_lines = diff['diff'].count('\n-') total_changes = added_lines + removed_lines - if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ - "project": project_name, - "commit_sha": commit_sha, - "diff": final_code - }) - print(f"Commit code:{final_code}") + if change_type == "new": + if added_lines > 1: + final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) + results.append({ + "commit_sha": commit_sha, + "author_name": author_name, + "diff": final_code + }) + else: + if total_changes > 1: + final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) + final_code_escaped = 
json.dumps(final_code)[1:-1] # Escape the final code + results.append({ + "commit_sha": commit_sha, + "author_name": author_name, + "diff": final_code_escaped + }) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index fc4e7eb7bb..dd4e31d663 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -2,24 +2,24 @@ identity: name: gitlab_commits author: Leo.Wang label: - en_US: Gitlab Commits - zh_Hans: Gitlab代码提交内容 + en_US: GitLab Commits + zh_Hans: GitLab 提交内容查询 description: human: - en_US: A tool for query gitlab commits. Input should be a exists username. - zh_Hans: 一个用于查询gitlab代码提交记录的的工具,输入的内容应该是一个已存在的用户名或者项目名。 - llm: A tool for query gitlab commits. Input should be a exists username or project. + en_US: A tool for query GitLab commits, Input should be a exists username or projec. + zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。 + llm: A tool for query GitLab commits, Input should be a exists username or project. 
parameters: - - name: employee + - name: username type: string required: false label: - en_US: employee + en_US: username zh_Hans: 员工用户名 human_description: - en_US: employee + en_US: username zh_Hans: 员工用户名 - llm_description: employee for gitlab + llm_description: User name for GitLab form: llm - name: project type: string @@ -30,7 +30,7 @@ parameters: human_description: en_US: project zh_Hans: 项目名 - llm_description: project for gitlab + llm_description: project for GitLab form: llm - name: start_time type: string @@ -41,7 +41,7 @@ parameters: human_description: en_US: start_time zh_Hans: 开始时间 - llm_description: start_time for gitlab + llm_description: Start time for GitLab form: llm - name: end_time type: string @@ -52,5 +52,26 @@ parameters: human_description: en_US: end_time zh_Hans: 结束时间 - llm_description: end_time for gitlab + llm_description: End time for GitLab + form: llm + - name: change_type + type: select + required: false + options: + - value: all + label: + en_US: all + zh_Hans: 所有 + - value: new + label: + en_US: new + zh_Hans: 新增 + default: all + label: + en_US: change_type + zh_Hans: 变更类型 + human_description: + en_US: change_type + zh_Hans: 变更类型 + llm_description: Content change type for GitLab form: llm diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py new file mode 100644 index 0000000000..7fa1d0d112 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -0,0 +1,95 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabFilesTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + project = tool_parameters.get('project', '') + branch = tool_parameters.get('branch', '') + path = tool_parameters.get('path', '') 
+ + + if not project: + return self.create_text_message('Project is required') + if not branch: + return self.create_text_message('Branch is required') + + if not path: + return self.create_text_message('Path is required') + + access_token = self.runtime.credentials.get('access_tokens') + site_url = self.runtime.credentials.get('site_url') + + if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + return self.create_text_message("Gitlab API Access Tokens is required.") + if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): + site_url = 'https://gitlab.com' + + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, project) + if not project_id: + return self.create_text_message(f"Project '{project}' not found.") + + # Get commit content + result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + + return [self.create_json_message(item) for item in result] + + def extract_project_name_and_path(self, path: str) -> tuple[str, str]: + parts = path.split('/', 1) + if len(parts) < 2: + return None, None + return parts[0], parts[1] + + def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: + headers = {"PRIVATE-TOKEN": access_token} + try: + url = f"{site_url}/api/v4/projects?search={project_name}" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + for project in projects: + if project['name'] == project_name: + return project['id'] + except requests.RequestException as e: + print(f"Error fetching project ID from GitLab: {e}") + return None + + def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + # List files and directories in the given path + url = 
f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + response = requests.get(url, headers=headers) + response.raise_for_status() + items = response.json() + + for item in items: + item_path = item['path'] + if item['type'] == 'tree': # It's a directory + results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) + else: # It's a file + file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + file_response = requests.get(file_url, headers=headers) + file_response.raise_for_status() + file_content = file_response.text + results.append({ + "path": item_path, + "branch": branch, + "content": file_content + }) + except requests.RequestException as e: + print(f"Error fetching data from GitLab: {e}") + + return results \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml new file mode 100644 index 0000000000..d99b6254c1 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -0,0 +1,45 @@ +identity: + name: gitlab_files + author: Leo.Wang + label: + en_US: GitLab Files + zh_Hans: GitLab 文件获取 +description: + human: + en_US: A tool for query GitLab files, Input should be branch and a exists file or directory path. + zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 + llm: A tool for query GitLab files, Input should be a exists file or directory path. 
+parameters: + - name: project + type: string + required: true + label: + en_US: project + zh_Hans: 项目 + human_description: + en_US: project + zh_Hans: 项目 + llm_description: Project for GitLab + form: llm + - name: branch + type: string + required: true + label: + en_US: branch + zh_Hans: 分支 + human_description: + en_US: branch + zh_Hans: 分支 + llm_description: Branch for GitLab + form: llm + - name: path + type: string + required: true + label: + en_US: path + zh_Hans: 文件路径 + human_description: + en_US: path + zh_Hans: 文件路径 + llm_description: File path for GitLab + form: llm diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py new file mode 100644 index 0000000000..0d018e3ca2 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -0,0 +1,43 @@ +from typing import Any + +from core.helper import ssrf_proxy +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class JinaTokenizerTool(BuiltinTool): + _jina_tokenizer_endpoint = 'https://tokenize.jina.ai/' + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> ToolInvokeMessage: + content = tool_parameters['content'] + body = { + "content": content + } + + headers = { + 'Content-Type': 'application/json' + } + + if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): + headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + + if tool_parameters.get('return_chunks', False): + body['return_chunks'] = True + + if tool_parameters.get('return_tokens', False): + body['return_tokens'] = True + + if tokenizer := tool_parameters.get('tokenizer'): + body['tokenizer'] = tokenizer + + response = ssrf_proxy.post( + self._jina_tokenizer_endpoint, + headers=headers, + json=body, + ) + + return self.create_json_message(response.json()) diff --git 
a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml new file mode 100644 index 0000000000..62a5c7e7ba --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml @@ -0,0 +1,70 @@ +identity: + name: jina_tokenizer + author: hjlarry + label: + en_US: JinaTokenizer +description: + human: + en_US: Free API to tokenize text and segment long text into chunks. + zh_Hans: 免费的API可以将文本tokenize,也可以将长文本分割成多个部分。 + llm: Free API to tokenize text and segment long text into chunks. +parameters: + - name: content + type: string + required: true + label: + en_US: Content + zh_Hans: 内容 + llm_description: the content which need to tokenize or segment + form: llm + - name: return_tokens + type: boolean + required: false + label: + en_US: Return the tokens + zh_Hans: 是否返回tokens + human_description: + en_US: Return the tokens and their corresponding ids in the response. + zh_Hans: 返回tokens及其对应的ids。 + form: form + - name: return_chunks + type: boolean + label: + en_US: Return the chunks + zh_Hans: 是否分块 + human_description: + en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues. 
+ zh_Hans: 将输入分块为具有语义意义的片段,同时根据常见的结构线索处理各种文本类型和边缘情况。 + form: form + - name: tokenizer + type: select + options: + - value: cl100k_base + label: + en_US: cl100k_base + - value: o200k_base + label: + en_US: o200k_base + - value: p50k_base + label: + en_US: p50k_base + - value: r50k_base + label: + en_US: r50k_base + - value: p50k_edit + label: + en_US: p50k_edit + - value: gpt2 + label: + en_US: gpt2 + label: + en_US: Tokenizer + human_description: + en_US: | + · cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5 + · o200k_base --- gpt-4o, gpt-4o-mini + · p50k_base --- text-davinci-003, text-davinci-002 + · r50k_base --- text-davinci-001, text-curie-001 + · p50k_edit --- text-davinci-edit-001, code-davinci-edit-001 + · gpt2 --- gpt-2 + form: form diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f7911fea1d..f14abac767 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import VariableEntity +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -18,6 +18,13 @@ from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow +VARIABLE_TO_PARAMETER_TYPE_MAPPING = { + VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, + VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, + VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, + VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, +} + class WorkflowToolProviderController(ToolProviderController): provider_id: str @@ -28,7 +35,7 @@ class 
WorkflowToolProviderController(ToolProviderController): if not app: raise ValueError('app not found') - + controller = WorkflowToolProviderController(**{ 'identity': { 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', @@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController): 'credentials_schema': {}, 'provider_id': db_provider.id or '', }) - + # init tools controller.tools = [controller._get_db_provider_tool(db_provider, app)] @@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ get db provider tool @@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController): if variable: parameter_type = None options = None - if variable.type in [ - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.PARAGRAPH, - ]: - parameter_type = ToolParameter.ToolParameterType.STRING - elif variable.type in [ - VariableEntity.Type.SELECT - ]: - parameter_type = ToolParameter.ToolParameterType.SELECT - elif variable.type in [ - VariableEntity.Type.NUMBER - ]: - parameter_type = ToolParameter.ToolParameterType.NUMBER - else: + if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f'unsupported variable type {variable.type}') - - if variable.type == VariableEntity.Type.SELECT and variable.options: + parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] + + if variable.type == VariableEntityType.SELECT and variable.options: options = [ ToolParameterOption( value=option, @@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController): """ if self.tools is not None: return self.tools - + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == 
self.provider_id, @@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController): if not db_providers: return [] - + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ get tool by name @@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController): for tool in self.tools: if tool.identity.name == tool_name: return tool - + return None diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 5d561911d1..d990131b5f 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -2,13 +2,12 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from copy import deepcopy from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator from pydantic_core.core_schema import ValidationInfo from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, @@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import ( from core.tools.tool_file_manager import ToolFileManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class Tool(BaseModel, ABC): identity: Optional[ToolIdentity] = None @@ -76,7 +78,7 @@ class Tool(BaseModel, ABC): description=self.description.model_copy() if self.description else None, runtime=Tool.Runtime(**runtime), ) - + @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ @@ -84,7 +86,7 @@ class Tool(BaseModel, ABC): :return: the tool provider type """ - + def load_variables(self, variables: ToolRuntimeVariablePool): """ load variables from database @@ -99,7 +101,7 @@ class Tool(BaseModel, ABC): """ if not 
self.variables: return - + self.variables.set_file(self.identity.name, variable_name, image_key) def set_text_variable(self, variable_name: str, text: str) -> None: @@ -108,9 +110,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return - + self.variables.set_text(self.identity.name, variable_name, text) - + def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ get a variable @@ -120,14 +122,14 @@ class Tool(BaseModel, ABC): """ if not self.variables: return None - + if isinstance(name, Enum): name = name.value - + for variable in self.variables.pool: if variable.name == name: return variable - + return None def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: @@ -138,9 +140,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return None - + return self.get_variable(self.VARIABLE_KEY.IMAGE) - + def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ get a variable file @@ -151,7 +153,7 @@ class Tool(BaseModel, ABC): variable = self.get_variable(name) if not variable: return None - + if not isinstance(variable, ToolRuntimeImageVariable): return None @@ -160,9 +162,9 @@ class Tool(BaseModel, ABC): file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) if not file_binary: return None - + return file_binary[0] - + def list_variables(self) -> list[ToolRuntimeVariable]: """ list all variables @@ -171,9 +173,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return [] - + return self.variables.pool - + def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ list all image variables @@ -182,9 +184,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return [] - + result = [] - + for variable in self.variables.pool: if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): result.append(variable) @@ -225,7 +227,7 @@ class Tool(BaseModel, ABC): @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> 
Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - + def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ validate the credentials @@ -244,7 +246,7 @@ class Tool(BaseModel, ABC): :return: the runtime parameters """ return self.parameters or [] - + def get_all_runtime_parameters(self) -> list[ToolParameter]: """ get all runtime parameters @@ -278,7 +280,7 @@ class Tool(BaseModel, ABC): parameters.append(parameter) return parameters - + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: """ create an image message @@ -286,18 +288,18 @@ class Tool(BaseModel, ABC): :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, + message=image, save_as=save_as) - - def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: + + def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={ 'file_var': file_var }, save_as='') - + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message @@ -305,10 +307,10 @@ class Tool(BaseModel, ABC): :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, + message=link, save_as=save_as) - + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: """ create a text message @@ -321,7 +323,7 @@ class Tool(BaseModel, ABC): message=text, save_as=save_as ) - + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/tool_manager.py 
b/api/core/tools/tool_manager.py index d7ddb40e6b..4a0188af49 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,14 +10,11 @@ from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, -) +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort @@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolConfigurationManager, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) + class ToolManager: _builtin_provider_lock = Lock() _builtin_providers = {} @@ -107,7 +102,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: 
ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -346,7 +341,7 @@ class ToolManager: provider_class = load_single_subclass_from_source( module_name=f'core.tools.provider.builtin.{provider}.{provider}', script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), + 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider @@ -414,6 +409,15 @@ class ToolManager: # append builtin providers for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name + ): + continue + user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), @@ -473,7 +477,7 @@ class ToolManager: @classmethod def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + ApiToolProviderController, dict[str, Any]]: """ get the api provider @@ -593,4 +597,5 @@ class ToolManager: else: raise ValueError(f"provider type {provider_type} not found") + ToolManager.load_builtin_providers_cache() diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae..564b9d3e14 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,7 @@ import logging from mimetypes import guess_extension -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileTransferMethod, FileType from core.tools.entities.tool_entities import 
ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -27,12 +27,12 @@ class ToolFileMessageTransformer: # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) - + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' result.append(ToolInvokeMessage( @@ -55,14 +55,14 @@ class ToolFileMessageTransformer: # if message is str, encode it to bytes if isinstance(message.message, str): message.message = message.message.encode('utf-8') - + file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, mimetype=mimetype ) - + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image @@ -81,7 +81,7 @@ class ToolFileMessageTransformer: meta=message.meta.copy() if message.meta is not None else {}, )) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var: FileVar = message.meta.get('file_var') + file_var = message.meta.get('file_var') if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) @@ -103,7 +103,7 @@ class ToolFileMessageTransformer: result.append(message) return result - + @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: - return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file + return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 9fe3356faa..8120b2ac78 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -6,20 +6,20 @@ from typing_extensions import deprecated from core.app.segments import Segment, Variable, factory from core.file.file_obj 
import FileVar -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey VariableValue = Union[str, int, float, dict, list, FileVar] -SYSTEM_VARIABLE_NODE_ID = 'sys' -ENVIRONMENT_VARIABLE_NODE_ID = 'env' -CONVERSATION_VARIABLE_NODE_ID = 'conversation' +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" class VariablePool: def __init__( self, - system_variables: Mapping[SystemVariable, Any], + system_variables: Mapping[SystemVariableKey, Any], user_inputs: Mapping[str, Any], environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable] | None = None, @@ -68,7 +68,7 @@ class VariablePool: None """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") if value is None: return @@ -95,13 +95,13 @@ class VariablePool: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value - @deprecated('This method is deprecated, use `get` instead.') + @deprecated("This method is deprecated, use `get` instead.") def get_any(self, selector: Sequence[str], /) -> Any | None: """ Retrieves the value from the variable pool based on the given selector. @@ -116,7 +116,7 @@ class VariablePool: ValueError: If the selector is invalid. 
""" if len(selector) < 2: - raise ValueError('Invalid selector') + raise ValueError("Invalid selector") hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) return value.to_object() if value else None diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 4757cf32f8..da65f6b1fb 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,25 +1,13 @@ from enum import Enum -class SystemVariable(str, Enum): +class SystemVariableKey(str, Enum): """ System Variables. """ - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - DIALOGUE_COUNT = 'dialogue_count' - @classmethod - def value_of(cls, value: str): - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 60678bc2ba..4a010623f9 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -14,7 +14,7 @@ from models.workflow import WorkflowNodeExecutionStatus MAX_NUMBER = dify_config.CODE_MAX_NUMBER MIN_NUMBER = dify_config.CODE_MIN_NUMBER MAX_PRECISION = 20 -MAX_DEPTH = 5 +MAX_DEPTH = dify_config.CODE_MAX_DEPTH MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 97b64d4b05..c3e4949421 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ 
b/api/core/workflow/nodes/llm/llm_node.py @@ -1,14 +1,13 @@ import json from collections.abc import Generator from copy import deepcopy -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage @@ -25,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import ( LLMNodeChatModelMessage, @@ -39,6 +38,10 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.file.file_obj import FileVar + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData @@ -71,7 +74,7 @@ class LLMNode(BaseNode): node_inputs = {} # fetch files - files: list[FileVar] = self._fetch_files(node_data, variable_pool) + files = self._fetch_files(node_data, variable_pool) if files: node_inputs['#files#'] = [file.to_dict() for file in files] @@ -91,7 +94,7 @@ class LLMNode(BaseNode): # fetch prompt messages prompt_messages, stop = 
self._fetch_prompt_messages( node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) + query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value]) if node_data.memory else None, query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, inputs=inputs, @@ -322,7 +325,7 @@ class LLMNode(BaseNode): return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: """ Fetch files :param node_data: node data @@ -332,7 +335,7 @@ class LLMNode(BaseNode): if not node_data.vision.enabled: return [] - files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) + files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value]) if not files: return [] @@ -497,7 +500,7 @@ class LLMNode(BaseNode): return None # get conversation id - conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value]) + conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value]) if conversation_id is None: return None @@ -521,7 +524,7 @@ class LLMNode(BaseNode): query: Optional[str], query_prompt_template: Optional[str], inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ @@ -669,10 +672,10 @@ class LLMNode(BaseNode): variable_mapping['#context#'] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] + variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value] if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] + variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value] if node_data.prompt_config: enable_jinja = False diff --git 
a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 0bd5f203bf..b81ce15bd7 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,3 +1,7 @@ +from collections.abc import Sequence + +from pydantic import Field + from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ - variables: list[VariableEntity] = [] + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 661b403d32..54e66bd671 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -17,16 +17,16 @@ class StartNode(BaseNode): :param variable_pool: variable pool :return: """ - # Get cleaned inputs - cleaned_inputs = dict(variable_pool.user_inputs) + node_inputs = dict(variable_pool.user_inputs) + system_inputs = variable_pool.system_variables - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + for var in system_inputs: + node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' 
+ var] = system_inputs[var] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=cleaned_inputs, - outputs=cleaned_inputs + inputs=node_inputs, + outputs=node_inputs ) @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 554e3b6074..ccce9ef360 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from os import path from typing import Any, cast -from core.app.segments import ArrayAnyVariable, parser +from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -141,8 +141,8 @@ class ToolNode(BaseNode): return result def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - assert isinstance(variable, ArrayAnyVariable) + variable = variable_pool.get(['sys', SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] def _convert_tool_messages(self, messages: 
list[ToolInvokeMessage]): diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index 552cc367f2..d791d51523 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -1,109 +1,8 @@ -from collections.abc import Sequence -from enum import Enum -from typing import Optional, cast +from .node import VariableAssignerNode +from .node_data import VariableAssignerData, WriteMode -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus - - -class VariableAssignerNodeError(Exception): - pass - - -class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' - - -class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] - - -class VariableAssignerNode(BaseNode): - _node_data_cls: type[BaseNodeData] = VariableAssignerData - _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - data = cast(VariableAssignerData, self.node_data) - - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) - if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') - - match data.write_mode: - 
case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) - - case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) - - case WriteMode.CLEAR: - income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) - - case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') - - # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) - - # Update conversation variable. - # TODO: Find a better way to use the database. 
- conversation_id = variable_pool.get(['sys', 'conversation_id']) - if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') - update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - 'value': income_value.to_object(), - }, - ) - - -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') - row.data = variable.model_dump_json() - session.commit() - - -def get_zero_value(t: SegmentType): - match t: - case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) - case SegmentType.OBJECT: - return factory.build_segment({}) - case SegmentType.STRING: - return factory.build_segment('') - case SegmentType.NUMBER: - return factory.build_segment(0) - case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') +__all__ = [ + 'VariableAssignerNode', + 'VariableAssignerData', + 'WriteMode', +] diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py new file mode 100644 index 0000000000..914be22256 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/exc.py @@ -0,0 +1,2 @@ +class VariableAssignerNodeError(Exception): + pass diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py new file mode 100644 index 0000000000..8c2adcabb9 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -0,0 +1,92 @@ +from typing import cast + +from sqlalchemy import select +from sqlalchemy.orm import 
Session + +from core.app.segments import SegmentType, Variable, factory +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.base_node import BaseNode +from extensions.ext_database import db +from models import ConversationVariable, WorkflowNodeExecutionStatus + +from .exc import VariableAssignerNodeError +from .node_data import VariableAssignerData, WriteMode + + +class VariableAssignerNode(BaseNode): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self, variable_pool: VariablePool) -> NodeRunResult: + data = cast(VariableAssignerData, self.node_data) + + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = variable_pool.get(data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError('assigned variable not found') + + match data.write_mode: + case WriteMode.OVER_WRITE: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_variable = original_variable.model_copy(update={'value': income_value.value}) + + case WriteMode.APPEND: + income_value = variable_pool.get(data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError('input value not found') + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={'value': updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') + + # Over 
write the variable. + variable_pool.add(data.assigned_variable_selector, updated_variable) + + # TODO: Move database operation to the pipeline. + # Update conversation variable. + conversation_id = variable_pool.get(['sys', 'conversation_id']) + if not conversation_id: + raise VariableAssignerNodeError('conversation_id not found') + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'value': income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError('conversation variable not found in the database') + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return factory.build_segment([]) + case SegmentType.OBJECT: + return factory.build_segment({}) + case SegmentType.STRING: + return factory.build_segment('') + case SegmentType.NUMBER: + return factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f'unsupported variable type: {t}') diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py new file mode 100644 index 0000000000..b3652b6802 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -0,0 +1,19 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class WriteMode(str, Enum): + OVER_WRITE = 'over-write' + APPEND = 'append' + CLEAR = 'clear' + + +class 
VariableAssignerData(BaseNodeData): + title: str = 'Variable Assigner' + desc: Optional[str] = 'Assign a value to a variable' + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] diff --git a/api/events/app_event.py b/api/events/app_event.py index 67a5982527..f2ce71bbbb 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -1,13 +1,13 @@ from blinker import signal # sender: app -app_was_created = signal('app-was-created') +app_was_created = signal("app-was-created") # sender: app, kwargs: app_model_config -app_model_config_was_updated = signal('app-model-config-was-updated') +app_model_config_was_updated = signal("app-model-config-was-updated") # sender: app, kwargs: published_workflow -app_published_workflow_was_updated = signal('app-published-workflow-was-updated') +app_published_workflow_was_updated = signal("app-published-workflow-was-updated") # sender: app, kwargs: synced_draft_workflow -app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced') +app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py index d4a2b6f313..750b7424e2 100644 --- a/api/events/dataset_event.py +++ b/api/events/dataset_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: dataset -dataset_was_deleted = signal('dataset-was-deleted') +dataset_was_deleted = signal("dataset-was-deleted") diff --git a/api/events/document_event.py b/api/events/document_event.py index f95326630b..2c5a416a5e 100644 --- a/api/events/document_event.py +++ b/api/events/document_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_was_deleted = signal('document-was-deleted') +document_was_deleted = signal("document-was-deleted") diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py index 42f1c70614..7caa2d1cc9 100644 --- 
a/api/events/event_handlers/clean_when_dataset_deleted.py +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -5,5 +5,11 @@ from tasks.clean_dataset_task import clean_dataset_task @dataset_was_deleted.connect def handle(sender, **kwargs): dataset = sender - clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, - dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) + clean_dataset_task.delay( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + ) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 24022da15f..00a66f50ad 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -5,7 +5,7 @@ from tasks.clean_document_task import clean_document_task @document_was_deleted.connect def handle(sender, **kwargs): document_id = sender - dataset_id = kwargs.get('dataset_id') - doc_form = kwargs.get('doc_form') - file_id = kwargs.get('file_id') + dataset_id = kwargs.get("dataset_id") + doc_form = kwargs.get("doc_form") + file_id = kwargs.get("file_id") clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 68dae5a553..72a135e73d 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,21 +14,25 @@ from models.dataset import Document @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get('document_ids', None) + document_ids = kwargs.get("document_ids", None) documents = [] start_at = time.perf_counter() for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), 
fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document) + .filter( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) + .first() + ) if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -38,8 +42,8 @@ def handle(sender, **kwargs): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 31084ce0fe..57412cc4ad 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -10,7 +10,7 @@ def handle(sender, **kwargs): installed_app = InstalledApp( tenant_id=app.tenant_id, app_id=app.id, - app_owner_tenant_id=app.tenant_id + app_owner_tenant_id=app.tenant_id, ) db.session.add(installed_app) db.session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index f0eb7159b6..ab07c5d366 100644 --- 
a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -7,15 +7,16 @@ from models.model import Site def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender - account = kwargs.get('account') + account = kwargs.get("account") site = Site( app_id=app.id, title=app.name, - icon = app.icon, - icon_background = app.icon_background, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy='not_allow', - code=Site.generate_code(16) + customize_token_strategy="not_allow", + code=Site.generate_code(16), ) db.session.add(site) diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_messaeg_created.py index 8cf52bf8f5..843a232096 100644 --- a/api/events/event_handlers/deduct_quota_when_messaeg_created.py +++ b/api/events/event_handlers/deduct_quota_when_messaeg_created.py @@ -8,7 +8,7 @@ from models.provider import Provider, ProviderType @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return @@ -39,7 +39,7 @@ def handle(sender, **kwargs): elif quota_unit == QuotaUnit.CREDITS: used_quota = 1 - if 'gpt-4' in model_config.model: + if "gpt-4" in model_config.model: used_quota = 20 else: used_quota = 1 @@ -50,6 +50,6 @@ def handle(sender, **kwargs): Provider.provider_name == model_config.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) + 
Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1f6da34ee2..f96bb5ef74 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,8 +8,8 @@ from events.app_event import app_draft_workflow_was_synced @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get('synced_draft_workflow').graph_dict.get('nodes', []): - if node_data.get('data', {}).get('type') == NodeType.TOOL.value: + for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( @@ -23,7 +23,7 @@ def handle(sender, **kwargs): tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, provider_type=tool_entity.provider_type, - identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}', ) manager.delete_tool_parameters_cache() except: diff --git a/api/events/event_handlers/document_index_event.py b/api/events/event_handlers/document_index_event.py index 9c4e055deb..3d463fe5b3 100644 --- a/api/events/event_handlers/document_index_event.py +++ b/api/events/event_handlers/document_index_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: document -document_index_created = signal('document-index-created') +document_index_created = signal("document-index-created") diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py 
b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 2b202c53d0..59375b1a0b 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -7,13 +7,11 @@ from models.model import AppModelConfig @app_model_config_was_updated.connect def handle(sender, **kwargs): app = sender - app_model_config = kwargs.get('app_model_config') + app_model_config = kwargs.get("app_model_config") dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -29,16 +27,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -51,7 +45,7 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) or [] + tools = agent_mode.get("tools", []) or [] for tool in tools: if len(list(tool.keys())) != 1: continue @@ -63,11 +57,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict - datasets = dataset_configs.get('datasets', {}) or {} - for 
dataset in datasets.get('datasets', []) or []: + datasets = dataset_configs.get("datasets", {}) or {} + for dataset in datasets.get("datasets", []) or []: keys = list(dataset.keys()) - if len(keys) == 1 and keys[0] == 'dataset': - if dataset['dataset'].get('id'): - dataset_ids.add(dataset['dataset'].get('id')) + if len(keys) == 1 and keys[0] == "dataset": + if dataset["dataset"].get("id"): + dataset_ids.add(dataset["dataset"].get("id")) return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 996b1e9691..333b85ecb2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -11,13 +11,11 @@ from models.workflow import Workflow @app_published_workflow_was_updated.connect def handle(sender, **kwargs): app = sender - published_workflow = kwargs.get('published_workflow') + published_workflow = kwargs.get("published_workflow") published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id - ).all() + app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() removed_dataset_ids = [] if not app_dataset_joins: @@ -33,16 +31,12 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: db.session.query(AppDatasetJoin).filter( - AppDatasetJoin.app_id == app.id, - AppDatasetJoin.dataset_id == dataset_id + AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id ).delete() if added_dataset_ids: for dataset_id in added_dataset_ids: - app_dataset_join = AppDatasetJoin( - app_id=app.id, - dataset_id=dataset_id - ) + app_dataset_join = 
AppDatasetJoin(app_id=app.id, dataset_id=dataset_id) db.session.add(app_dataset_join) db.session.commit() @@ -54,18 +48,19 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: if not graph: return dataset_ids - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) # fetch all knowledge retrieval nodes - knowledge_retrieval_nodes = [node for node in nodes - if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value] + knowledge_retrieval_nodes = [ + node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value + ] if not knowledge_retrieval_nodes: return dataset_ids for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get('data', {})) + node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) dataset_ids.update(node_data.dataset_ids) except Exception as e: continue diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py index 6188f1a085..a80572c0de 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py @@ -9,13 +9,13 @@ from models.provider import Provider @message_was_created.connect def handle(sender, **kwargs): message = sender - application_generate_entity = kwargs.get('application_generate_entity') + application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): return db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, - Provider.provider_name == application_generate_entity.model_conf.provider - ).update({'last_used': datetime.now(timezone.utc).replace(tzinfo=None)}) + Provider.provider_name == 
application_generate_entity.model_conf.provider, + ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) db.session.commit() diff --git a/api/events/message_event.py b/api/events/message_event.py index 21da83f249..6576c35c45 100644 --- a/api/events/message_event.py +++ b/api/events/message_event.py @@ -1,4 +1,4 @@ from blinker import signal # sender: message, kwargs: conversation -message_was_created = signal('message-was-created') +message_was_created = signal("message-was-created") diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py index 942f709917..d99feaac40 100644 --- a/api/events/tenant_event.py +++ b/api/events/tenant_event.py @@ -1,7 +1,7 @@ from blinker import signal # sender: tenant -tenant_was_created = signal('tenant-was-created') +tenant_was_created = signal("tenant-was-created") # sender: tenant -tenant_was_updated = signal('tenant-was-updated') +tenant_was_updated = signal("tenant-was-updated") diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index ae9a075340..f5ec7c1759 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -17,7 +17,7 @@ def init_app(app: Flask) -> Celery: backend=app.config["CELERY_BACKEND"], task_ignore_result=True, ) - + # Add SSL options to the Celery configuration ssl_options = { "ssl_cert_reqs": None, @@ -35,7 +35,7 @@ def init_app(app: Flask) -> Celery: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) - + celery_app.set_default() app.extensions["celery"] = celery_app @@ -45,18 +45,15 @@ def init_app(app: Flask) -> Celery: ] day = app.config["CELERY_BEAT_SCHEDULER_TIME"] beat_schedule = { - 'clean_embedding_cache_task': { - 'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task', - 'schedule': timedelta(days=day), + "clean_embedding_cache_task": { + "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", + "schedule": timedelta(days=day), + }, + 
"clean_unused_datasets_task": { + "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", + "schedule": timedelta(days=day), }, - 'clean_unused_datasets_task': { - 'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task', - 'schedule': timedelta(days=day), - } } - celery_app.conf.update( - beat_schedule=beat_schedule, - imports=imports - ) + celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 1dbaffcfb0..38e67749fc 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -2,15 +2,14 @@ from flask import Flask def init_app(app: Flask): - if app.config.get('API_COMPRESSION_ENABLED'): + if app.config.get("API_COMPRESSION_ENABLED"): from flask_compress import Compress - app.config['COMPRESS_MIMETYPES'] = [ - 'application/json', - 'image/svg+xml', - 'text/html', + app.config["COMPRESS_MIMETYPES"] = [ + "application/json", + "image/svg+xml", + "text/html", ] compress = Compress() compress.init_app(app) - diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index c248e173a2..f6ffa53634 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -2,11 +2,11 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy import MetaData POSTGRES_INDEXES_NAMING_CONVENTION = { - 'ix': '%(column_0_label)s_idx', - 'uq': '%(table_name)s_%(column_0_name)s_key', - 'ck': '%(table_name)s_%(constraint_name)s_check', - 'fk': '%(table_name)s_%(column_0_name)s_fkey', - 'pk': '%(table_name)s_pkey', + "ix": "%(column_0_label)s_idx", + "uq": "%(table_name)s_%(column_0_name)s_key", + "ck": "%(table_name)s_%(constraint_name)s_check", + "fk": "%(table_name)s_%(column_0_name)s_fkey", + "pk": "%(table_name)s_pkey", } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 
ec3a5cc112..b435294abc 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -14,67 +14,69 @@ class Mail: return self._client is not None def init_app(self, app: Flask): - if app.config.get('MAIL_TYPE'): - if app.config.get('MAIL_DEFAULT_SEND_FROM'): - self._default_send_from = app.config.get('MAIL_DEFAULT_SEND_FROM') - - if app.config.get('MAIL_TYPE') == 'resend': - api_key = app.config.get('RESEND_API_KEY') - if not api_key: - raise ValueError('RESEND_API_KEY is not set') + if app.config.get("MAIL_TYPE"): + if app.config.get("MAIL_DEFAULT_SEND_FROM"): + self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") - api_url = app.config.get('RESEND_API_URL') + if app.config.get("MAIL_TYPE") == "resend": + api_key = app.config.get("RESEND_API_KEY") + if not api_key: + raise ValueError("RESEND_API_KEY is not set") + + api_url = app.config.get("RESEND_API_URL") if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get('MAIL_TYPE') == 'smtp': + elif app.config.get("MAIL_TYPE") == "smtp": from libs.smtp import SMTPClient - if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): - raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') - if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'): - raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS') + + if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") + if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get('SMTP_SERVER'), - port=app.config.get('SMTP_PORT'), - username=app.config.get('SMTP_USERNAME'), - password=app.config.get('SMTP_PASSWORD'), - 
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'), - use_tls=app.config.get('SMTP_USE_TLS'), - opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS') + server=app.config.get("SMTP_SERVER"), + port=app.config.get("SMTP_PORT"), + username=app.config.get("SMTP_USERNAME"), + password=app.config.get("SMTP_PASSWORD"), + _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), + use_tls=app.config.get("SMTP_USE_TLS"), + opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), ) else: - raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE'))) + raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) else: - logging.warning('MAIL_TYPE is not set') - + logging.warning("MAIL_TYPE is not set") def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: - raise ValueError('Mail client is not initialized') + raise ValueError("Mail client is not initialized") if not from_ and self._default_send_from: from_ = self._default_send_from if not from_: - raise ValueError('mail from is not set') + raise ValueError("mail from is not set") if not to: - raise ValueError('mail to is not set') + raise ValueError("mail to is not set") if not subject: - raise ValueError('mail subject is not set') + raise ValueError("mail subject is not set") if not html: - raise ValueError('mail html is not set') + raise ValueError("mail html is not set") - self._client.send({ - "from": from_, - "to": to, - "subject": subject, - "html": html - }) + self._client.send( + { + "from": from_, + "to": to, + "subject": subject, + "html": html, + } + ) def init_app(app: Flask): diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 23d7768d4d..d5fb162fd8 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -6,18 +6,21 @@ redis_client = redis.Redis() def init_app(app): connection_class = Connection - if app.config.get('REDIS_USE_SSL'): + if app.config.get("REDIS_USE_SSL"): connection_class = 
SSLConnection - redis_client.connection_pool = redis.ConnectionPool(**{ - 'host': app.config.get('REDIS_HOST'), - 'port': app.config.get('REDIS_PORT'), - 'username': app.config.get('REDIS_USERNAME'), - 'password': app.config.get('REDIS_PASSWORD'), - 'db': app.config.get('REDIS_DB'), - 'encoding': 'utf-8', - 'encoding_errors': 'strict', - 'decode_responses': False - }, connection_class=connection_class) + redis_client.connection_pool = redis.ConnectionPool( + **{ + "host": app.config.get("REDIS_HOST"), + "port": app.config.get("REDIS_PORT"), + "username": app.config.get("REDIS_USERNAME"), + "password": app.config.get("REDIS_PASSWORD"), + "db": app.config.get("REDIS_DB"), + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + }, + connection_class=connection_class, + ) - app.extensions['redis'] = redis_client + app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index f05c10bc08..227c6635f0 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,16 +5,13 @@ from werkzeug.exceptions import HTTPException def init_app(app): - if app.config.get('SENTRY_DSN'): + if app.config.get("SENTRY_DSN"): sentry_sdk.init( - dsn=app.config.get('SENTRY_DSN'), - integrations=[ - FlaskIntegration(), - CeleryIntegration() - ], + dsn=app.config.get("SENTRY_DSN"), + integrations=[FlaskIntegration(), CeleryIntegration()], ignore_errors=[HTTPException, ValueError], - traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), - profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), - environment=app.config.get('DEPLOY_ENV'), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" + traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), + profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), + environment=app.config.get("DEPLOY_ENV"), + 
release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 38db1c6ce1..e6c4352577 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -17,31 +17,19 @@ class Storage: self.storage_runner = None def init_app(self, app: Flask): - storage_type = app.config.get('STORAGE_TYPE') - if storage_type == 's3': - self.storage_runner = S3Storage( - app=app - ) - elif storage_type == 'azure-blob': - self.storage_runner = AzureStorage( - app=app - ) - elif storage_type == 'aliyun-oss': - self.storage_runner = AliyunStorage( - app=app - ) - elif storage_type == 'google-storage': - self.storage_runner = GoogleStorage( - app=app - ) - elif storage_type == 'tencent-cos': - self.storage_runner = TencentStorage( - app=app - ) - elif storage_type == 'oci-storage': - self.storage_runner = OCIStorage( - app=app - ) + storage_type = app.config.get("STORAGE_TYPE") + if storage_type == "s3": + self.storage_runner = S3Storage(app=app) + elif storage_type == "azure-blob": + self.storage_runner = AzureStorage(app=app) + elif storage_type == "aliyun-oss": + self.storage_runner = AliyunStorage(app=app) + elif storage_type == "google-storage": + self.storage_runner = GoogleStorage(app=app) + elif storage_type == "tencent-cos": + self.storage_runner = TencentStorage(app=app) + elif storage_type == "oci-storage": + self.storage_runner = OCIStorage(app=app) else: self.storage_runner = LocalStorage(app=app) diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py index b81a8691f1..b962cedc55 100644 --- a/api/extensions/storage/aliyun_storage.py +++ b/api/extensions/storage/aliyun_storage.py @@ -8,23 +8,22 @@ from extensions.storage.base_storage import BaseStorage class AliyunStorage(BaseStorage): - """Implementation for aliyun storage. 
- """ + """Implementation for aliyun storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('ALIYUN_OSS_BUCKET_NAME') + self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") oss_auth_method = aliyun_s3.Auth region = None - if app_config.get('ALIYUN_OSS_AUTH_VERSION') == 'v4': + if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get('ALIYUN_OSS_REGION') - oss_auth = oss_auth_method(app_config.get('ALIYUN_OSS_ACCESS_KEY'), app_config.get('ALIYUN_OSS_SECRET_KEY')) + region = app_config.get("ALIYUN_OSS_REGION") + oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) self.client = aliyun_s3.Bucket( oss_auth, - app_config.get('ALIYUN_OSS_ENDPOINT'), + app_config.get("ALIYUN_OSS_ENDPOINT"), self.bucket_name, connect_timeout=30, region=region, diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py index af3e7ef849..ca8cbb9188 100644 --- a/api/extensions/storage/azure_storage.py +++ b/api/extensions/storage/azure_storage.py @@ -9,16 +9,15 @@ from extensions.storage.base_storage import BaseStorage class AzureStorage(BaseStorage): - """Implementation for azure storage. 
- """ + """Implementation for azure storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('AZURE_BLOB_CONTAINER_NAME') - self.account_url = app_config.get('AZURE_BLOB_ACCOUNT_URL') - self.account_name = app_config.get('AZURE_BLOB_ACCOUNT_NAME') - self.account_key = app_config.get('AZURE_BLOB_ACCOUNT_KEY') + self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") + self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") + self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") + self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") def save(self, filename, data): client = self._sync_client() @@ -39,6 +38,7 @@ class AzureStorage(BaseStorage): blob = client.get_blob_client(container=self.bucket_name, blob=filename) blob_data = blob.download_blob() yield from blob_data.chunks() + return generate(filename) def download(self, filename, target_filepath): @@ -62,17 +62,17 @@ class AzureStorage(BaseStorage): blob_container.delete_blob(filename) def _sync_client(self): - cache_key = 'azure_blob_sas_token_{}_{}'.format(self.account_name, self.account_key) + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) cache_result = redis_client.get(cache_key) if cache_result is not None: - sas_token = cache_result.decode('utf-8') + sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( account_name=self.account_name, account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1) + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git 
a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 13d9c34290..c3fe9ec82a 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -1,4 +1,5 @@ """Abstract interface for file storage implementations.""" + from abc import ABC, abstractmethod from collections.abc import Generator @@ -6,8 +7,8 @@ from flask import Flask class BaseStorage(ABC): - """Interface for file storage. - """ + """Interface for file storage.""" + app = None def __init__(self, app: Flask): diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py index ef6cd69039..9ed1fcf0b4 100644 --- a/api/extensions/storage/google_storage.py +++ b/api/extensions/storage/google_storage.py @@ -11,16 +11,16 @@ from extensions.storage.base_storage import BaseStorage class GoogleStorage(BaseStorage): - """Implementation for google storage. - """ + """Implementation for google storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('GOOGLE_STORAGE_BUCKET_NAME') - service_account_json_str = app_config.get('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64') + self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") + service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") # if service_account_json_str is empty, use Application Default Credentials if service_account_json_str: - service_account_json = base64.b64decode(service_account_json_str).decode('utf-8') + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object service_account_obj = json.loads(service_account_json) self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) @@ -43,9 +43,10 @@ class GoogleStorage(BaseStorage): def generate(filename: str = filename) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - 
with closing(blob.open(mode='rb')) as blob_stream: + with closing(blob.open(mode="rb")) as blob_stream: while chunk := blob_stream.read(4096): yield chunk + return generate() def download(self, filename, target_filepath): @@ -60,4 +61,4 @@ class GoogleStorage(BaseStorage): def delete(self, filename): bucket = self.client.get_bucket(self.bucket_name) - bucket.delete_blob(filename) \ No newline at end of file + bucket.delete_blob(filename) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py index 389ef12f82..46ee4bf80f 100644 --- a/api/extensions/storage/local_storage.py +++ b/api/extensions/storage/local_storage.py @@ -8,21 +8,20 @@ from extensions.storage.base_storage import BaseStorage class LocalStorage(BaseStorage): - """Implementation for local storage. - """ + """Implementation for local storage.""" def __init__(self, app: Flask): super().__init__(app) - folder = self.app.config.get('STORAGE_LOCAL_PATH') + folder = self.app.config.get("STORAGE_LOCAL_PATH") if not os.path.isabs(folder): folder = os.path.join(app.root_path, folder) self.folder = folder def save(self, filename, data): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename folder = os.path.dirname(filename) os.makedirs(folder, exist_ok=True) @@ -31,10 +30,10 @@ class LocalStorage(BaseStorage): f.write(data) def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -46,10 +45,10 @@ class LocalStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) 
-> Generator: - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -61,10 +60,10 @@ class LocalStorage(BaseStorage): return generate() def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if not os.path.exists(filename): raise FileNotFoundError("File not found") @@ -72,17 +71,17 @@ class LocalStorage(BaseStorage): shutil.copyfile(filename, target_filepath) def exists(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename return os.path.exists(filename) def delete(self, filename): - if not self.folder or self.folder.endswith('/'): + if not self.folder or self.folder.endswith("/"): filename = self.folder + filename else: - filename = self.folder + '/' + filename + filename = self.folder + "/" + filename if os.path.exists(filename): os.remove(filename) diff --git a/api/extensions/storage/oci_storage.py b/api/extensions/storage/oci_storage.py index e78d870950..e32fa0a0ae 100644 --- a/api/extensions/storage/oci_storage.py +++ b/api/extensions/storage/oci_storage.py @@ -12,14 +12,14 @@ class OCIStorage(BaseStorage): def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('OCI_BUCKET_NAME') + self.bucket_name = app_config.get("OCI_BUCKET_NAME") self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('OCI_SECRET_KEY'), - 
aws_access_key_id=app_config.get('OCI_ACCESS_KEY'), - endpoint_url=app_config.get('OCI_ENDPOINT'), - region_name=app_config.get('OCI_REGION') - ) + "s3", + aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), + aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), + endpoint_url=app_config.get("OCI_ENDPOINT"), + region_name=app_config.get("OCI_REGION"), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -27,9 +27,9 @@ class OCIStorage(BaseStorage): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -40,12 +40,13 @@ class OCIStorage(BaseStorage): try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): @@ -61,4 +62,4 @@ class OCIStorage(BaseStorage): return False def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) \ No newline at end of file + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py index 787596fa79..022ce5b14a 100644 --- a/api/extensions/storage/s3_storage.py +++ b/api/extensions/storage/s3_storage.py @@ -10,24 +10,24 @@ from extensions.storage.base_storage import BaseStorage class 
S3Storage(BaseStorage): - """Implementation for s3 storage. - """ + """Implementation for s3 storage.""" + def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('S3_BUCKET_NAME') - if app_config.get('S3_USE_AWS_MANAGED_IAM'): + self.bucket_name = app_config.get("S3_BUCKET_NAME") + if app_config.get("S3_USE_AWS_MANAGED_IAM"): session = boto3.Session() - self.client = session.client('s3') + self.client = session.client("s3") else: self.client = boto3.client( - 's3', - aws_secret_access_key=app_config.get('S3_SECRET_KEY'), - aws_access_key_id=app_config.get('S3_ACCESS_KEY'), - endpoint_url=app_config.get('S3_ENDPOINT'), - region_name=app_config.get('S3_REGION'), - config=Config(s3={'addressing_style': app_config.get('S3_ADDRESS_STYLE')}) - ) + "s3", + aws_secret_access_key=app_config.get("S3_SECRET_KEY"), + aws_access_key_id=app_config.get("S3_ACCESS_KEY"), + endpoint_url=app_config.get("S3_ENDPOINT"), + region_name=app_config.get("S3_REGION"), + config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), + ) def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) @@ -35,9 +35,9 @@ class S3Storage(BaseStorage): def load_once(self, filename: str) -> bytes: try: with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read() + data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: - if ex.response['Error']['Code'] == 'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise @@ -48,12 +48,13 @@ class S3Storage(BaseStorage): try: with closing(self.client) as client: response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].iter_chunks() + yield from response["Body"].iter_chunks() except ClientError as ex: - if ex.response['Error']['Code'] == 
'NoSuchKey': + if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") else: raise + return generate() def download(self, filename, target_filepath): diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index e2c1ca55e3..1d499cd3bc 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -7,18 +7,17 @@ from extensions.storage.base_storage import BaseStorage class TencentStorage(BaseStorage): - """Implementation for tencent cos storage. - """ + """Implementation for tencent cos storage.""" def __init__(self, app: Flask): super().__init__(app) app_config = self.app.config - self.bucket_name = app_config.get('TENCENT_COS_BUCKET_NAME') + self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") config = CosConfig( - Region=app_config.get('TENCENT_COS_REGION'), - SecretId=app_config.get('TENCENT_COS_SECRET_ID'), - SecretKey=app_config.get('TENCENT_COS_SECRET_KEY'), - Scheme=app_config.get('TENCENT_COS_SCHEME'), + Region=app_config.get("TENCENT_COS_REGION"), + SecretId=app_config.get("TENCENT_COS_SECRET_ID"), + SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), + Scheme=app_config.get("TENCENT_COS_SCHEME"), ) self.client = CosS3Client(config) @@ -26,19 +25,19 @@ class TencentStorage(BaseStorage): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].get_raw_stream().read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response['Body'].get_stream(chunk_size=4096) + yield from response["Body"].get_stream(chunk_size=4096) return generate() def 
download(self, filename, target_filepath): response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - response['Body'].get_stream_to_file(target_filepath) + response["Body"].get_stream_to_file(target_filepath) def exists(self, filename): return self.client.object_exists(Bucket=self.bucket_name, Key=filename) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index c778084475..379dcc6d16 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -5,7 +5,7 @@ from libs.helper import TimestampField annotation_fields = { "id": fields.String, "question": fields.String, - "answer": fields.Raw(attribute='content'), + "answer": fields.Raw(attribute="content"), "hit_count": fields.Integer, "created_at": TimestampField, # 'account': fields.Nested(simple_account_fields, allow_null=True) @@ -21,8 +21,8 @@ annotation_hit_history_fields = { "score": fields.Float, "question": fields.String, "created_at": TimestampField, - "match": fields.String(attribute='annotation_question'), - "response": fields.String(attribute='annotation_content') + "match": fields.String(attribute="annotation_question"), + "response": fields.String(attribute="annotation_content"), } annotation_hit_history_list_fields = { diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index 749e9900de..a85d4a34db 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -8,16 +8,16 @@ class HiddenAPIKey(fields.Raw): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: - return api_key[0] + '******' + api_key[-1] + return api_key[0] + "******" + api_key[-1] # If the api_key is greater than 8 characters, show the first three and the last three characters else: - return api_key[:3] + '******' + api_key[-3:] + return api_key[:3] + "******" + api_key[-3:] api_based_extension_fields 
= { - 'id': fields.String, - 'name': fields.String, - 'api_endpoint': fields.String, - 'api_key': HiddenAPIKey, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "api_endpoint": fields.String, + "api_key": HiddenAPIKey, + "created_at": TimestampField, } diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 94d804a919..26ed686783 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,159 +1,163 @@ from flask_restful import fields -from libs.helper import TimestampField +from libs.helper import AppIconUrlField, TimestampField app_detail_kernel_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } related_app_list = { - 'data': fields.List(fields.Nested(app_detail_kernel_fields)), - 'total': fields.Integer, + "data": fields.List(fields.Nested(app_detail_kernel_fields)), + "total": fields.Integer, } model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), - 'text_to_speech': fields.Raw(attribute='text_to_speech_dict'), - 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), - 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), - 'external_data_tools': 
fields.Raw(attribute='external_data_tools_list'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'dataset_query_variable': fields.String, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw(attribute='agent_mode_dict'), - 'prompt_type': fields.String, - 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), - 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), - 'dataset_configs': fields.Raw(attribute='dataset_configs_dict'), - 'file_upload': fields.Raw(attribute='file_upload_dict'), - 'created_at': TimestampField + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "speech_to_text": fields.Raw(attribute="speech_to_text_dict"), + "text_to_speech": fields.Raw(attribute="text_to_speech_dict"), + "retriever_resource": fields.Raw(attribute="retriever_resource_dict"), + "annotation_reply": fields.Raw(attribute="annotation_reply_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "sensitive_word_avoidance": fields.Raw(attribute="sensitive_word_avoidance_dict"), + "external_data_tools": fields.Raw(attribute="external_data_tools_list"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "dataset_query_variable": fields.String, + "pre_prompt": fields.String, + "agent_mode": fields.Raw(attribute="agent_mode_dict"), + "prompt_type": fields.String, + "chat_prompt_config": fields.Raw(attribute="chat_prompt_config_dict"), + "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), + "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), + "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_at": TimestampField, } app_detail_fields = { - 'id': fields.String, - 
'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'tracing': fields.Raw, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon": fields.String, + "icon_background": fields.String, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "tracing": fields.Raw, + "created_at": TimestampField, } prompt_config_fields = { - 'prompt_template': fields.String, + "prompt_template": fields.String, } model_config_partial_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} app_partial_fields = { - 'id': fields.String, - 'name': fields.String, - 'max_active_requests': fields.Raw(), - 'description': fields.String(attribute='desc_or_prompt'), - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True), - 'created_at': TimestampField, - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "max_active_requests": fields.Raw(), + "description": fields.String(attribute="desc_or_prompt"), + "mode": fields.String(attribute="mode_compatible_with_agent"), + 
"icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "created_at": TimestampField, + "tags": fields.List(fields.Nested(tag_fields)), } app_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(app_partial_fields), attribute="items"), } template_fields = { - 'name': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'mode': fields.String, - 'model_config': fields.Nested(model_config_fields), + "name": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "mode": fields.String, + "model_config": fields.Nested(model_config_fields), } template_list_fields = { - 'data': fields.List(fields.Nested(template_fields)), + "data": fields.List(fields.Nested(template_fields)), } site_fields = { - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'app_base_url': fields.String, - 'show_workflow_steps': fields.Boolean, + "access_token": 
fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "default_language": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "app_base_url": fields.String, + "show_workflow_steps": fields.Boolean, } app_detail_fields_with_site = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'mode': fields.String(attribute='mode_compatible_with_agent'), - 'icon': fields.String, - 'icon_background': fields.String, - 'enable_site': fields.Boolean, - 'enable_api': fields.Boolean, - 'model_config': fields.Nested(model_config_fields, attribute='app_model_config', allow_null=True), - 'site': fields.Nested(site_fields), - 'api_base_url': fields.String, - 'created_at': TimestampField, - 'deleted_tools': fields.List(fields.String), + "id": fields.String, + "name": fields.String, + "description": fields.String, + "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "enable_site": fields.Boolean, + "enable_api": fields.Boolean, + "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "site": fields.Nested(site_fields), + "api_base_url": fields.String, + "created_at": TimestampField, + "deleted_tools": fields.List(fields.String), } app_site_fields = { - 'app_id': fields.String, - 'access_token': fields.String(attribute='code'), - 'code': fields.String, - 'title': fields.String, - 'icon': fields.String, - 'icon_background': 
fields.String, - 'description': fields.String, - 'default_language': fields.String, - 'customize_domain': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'customize_token_strategy': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "app_id": fields.String, + "access_token": fields.String(attribute="code"), + "code": fields.String, + "title": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "description": fields.String, + "default_language": fields.String, + "customize_domain": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "customize_token_strategy": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 79ceb02685..3a64801e18 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -6,205 +6,203 @@ from libs.helper import TimestampField class MessageTextField(fields.Raw): def format(self, value): - return value[0]['text'] if value else '' + return value[0]["text"] if value else "" feedback_fields = { - 'rating': fields.String, - 'content': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account': fields.Nested(simple_account_fields, allow_null=True), + "rating": fields.String, + "content": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account": fields.Nested(simple_account_fields, allow_null=True), } annotation_fields = { - 'id': fields.String, - 'question': fields.String, - 'content': fields.String, - 'account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "question": fields.String, + "content": fields.String, + "account": 
fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } annotation_hit_history_fields = { - 'annotation_id': fields.String(attribute='id'), - 'annotation_create_account': fields.Nested(simple_account_fields, allow_null=True), - 'created_at': TimestampField + "annotation_id": fields.String(attribute="id"), + "annotation_create_account": fields.Nested(simple_account_fields, allow_null=True), + "created_at": TimestampField, } message_file_fields = { - 'id': fields.String, - 'type': fields.String, - 'url': fields.String, - 'belongs_to': fields.String(default='user'), + "id": fields.String, + "type": fields.String, + "url": fields.String, + "belongs_to": fields.String(default="user"), } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String), + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } message_detail_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'message': fields.Raw, - 'message_tokens': fields.Integer, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'answer_tokens': fields.Integer, - 'provider_response_latency': fields.Float, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'feedbacks': fields.List(fields.Nested(feedback_fields)), - 'workflow_run_id': fields.String, - 'annotation': fields.Nested(annotation_fields, 
allow_null=True), - 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'metadata': fields.Raw(attribute='message_metadata_dict'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": fields.Raw, + "message_tokens": fields.Integer, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "answer_tokens": fields.Integer, + "provider_response_latency": fields.Float, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "feedbacks": fields.List(fields.Nested(feedback_fields)), + "workflow_run_id": fields.String, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "metadata": fields.Raw(attribute="message_metadata_dict"), + "status": fields.String, + "error": fields.String, } -feedback_stat_fields = { - 'like': fields.Integer, - 'dislike': fields.Integer -} +feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'model': fields.Raw, - 'user_input_form': fields.Raw, - 'pre_prompt': fields.String, - 'agent_mode': fields.Raw, + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "model": fields.Raw, + "user_input_form": fields.Raw, + "pre_prompt": fields.String, + "agent_mode": fields.Raw, } simple_configs_fields = { - 
'prompt_template': fields.String, + "prompt_template": fields.String, } simple_model_config_fields = { - 'model': fields.Raw(attribute='model_dict'), - 'pre_prompt': fields.String, + "model": fields.Raw(attribute="model_dict"), + "pre_prompt": fields.String, } simple_message_detail_fields = { - 'inputs': fields.Raw, - 'query': fields.String, - 'message': MessageTextField, - 'answer': fields.String, + "inputs": fields.Raw, + "query": fields.String, + "message": MessageTextField, + "answer": fields.String, } conversation_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String(), - 'from_account_id': fields.String, - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotation': fields.Nested(annotation_fields, allow_null=True), - 'model_config': fields.Nested(simple_model_config_fields), - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields), - 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String(), + "from_account_id": fields.String, + "read_at": TimestampField, + "created_at": TimestampField, + "annotation": fields.Nested(annotation_fields, allow_null=True), + "model_config": fields.Nested(simple_model_config_fields), + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), + "message": fields.Nested(simple_message_detail_fields, attribute="first_message"), } conversation_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_fields), 
attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_fields), attribute="items"), } conversation_message_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'model_config': fields.Nested(model_config_fields), - 'message': fields.Nested(message_detail_fields, attribute='first_message'), + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "model_config": fields.Nested(model_config_fields), + "message": fields.Nested(message_detail_fields, attribute="first_message"), } conversation_with_summary_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_end_user_session_id': fields.String, - 'from_account_id': fields.String, - 'name': fields.String, - 'summary': fields.String(attribute='summary_or_query'), - 'read_at': TimestampField, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'model_config': fields.Nested(simple_model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_end_user_session_id": fields.String, + "from_account_id": fields.String, + "name": fields.String, + "summary": fields.String(attribute="summary_or_query"), + "read_at": TimestampField, + "created_at": TimestampField, + "updated_at": TimestampField, + "annotated": fields.Boolean, + "model_config": 
fields.Nested(simple_model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } conversation_with_summary_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"), } conversation_detail_fields = { - 'id': fields.String, - 'status': fields.String, - 'from_source': fields.String, - 'from_end_user_id': fields.String, - 'from_account_id': fields.String, - 'created_at': TimestampField, - 'annotated': fields.Boolean, - 'introduction': fields.String, - 'model_config': fields.Nested(model_config_fields), - 'message_count': fields.Integer, - 'user_feedback_stats': fields.Nested(feedback_stat_fields), - 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + "id": fields.String, + "status": fields.String, + "from_source": fields.String, + "from_end_user_id": fields.String, + "from_account_id": fields.String, + "created_at": TimestampField, + "annotated": fields.Boolean, + "introduction": fields.String, + "model_config": fields.Nested(model_config_fields), + "message_count": fields.Integer, + "user_feedback_stats": fields.Nested(feedback_stat_fields), + "admin_feedback_stats": fields.Nested(feedback_stat_fields), } simple_conversation_fields = { - 'id': fields.String, - 'name': fields.String, - 'inputs': fields.Raw, - 'status': fields.String, - 'introduction': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "inputs": fields.Raw, + "status": fields.String, 
+ "introduction": fields.String, + "created_at": TimestampField, } conversation_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(simple_conversation_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(simple_conversation_fields)), } conversation_with_model_config_fields = { **simple_conversation_fields, - 'model_config': fields.Raw, + "model_config": fields.Raw, } conversation_with_model_config_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(conversation_with_model_config_fields)), } diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 782a848c1a..983e50e73c 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -3,19 +3,19 @@ from flask_restful import fields from libs.helper import TimestampField conversation_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value_type': fields.String(attribute='value_type.value'), - 'value': fields.String, - 'description': fields.String, - 'created_at': TimestampField, - 'updated_at': TimestampField, + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.String, + "description": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, } paginated_conversation_variable_fields = { - 'page': fields.Integer, - 'limit': fields.Integer, - 'total': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(conversation_variable_fields), attribute='data'), + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": 
fields.Boolean, + "data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"), } diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85..071071376f 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -2,64 +2,56 @@ from flask_restful import fields from libs.helper import TimestampField -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'is_bound': fields.Boolean, - 'parent_id': fields.String, - 'type': fields.String + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)) + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), } integrate_notion_info_list_fields = { - 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), + "notion_info": fields.List(fields.Nested(integrate_workspace_fields)), } -integrate_icon_fields = { - 'type': fields.String, - 'url': fields.String, - 'emoji': fields.String -} +integrate_icon_fields = {"type": fields.String, "url": fields.String, "emoji": fields.String} integrate_page_fields = { - 'page_name': fields.String, - 'page_id': fields.String, - 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), - 'parent_id': fields.String, - 'type': fields.String + 
"page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_fields, allow_null=True), + "parent_id": fields.String, + "type": fields.String, } integrate_workspace_fields = { - 'workspace_name': fields.String, - 'workspace_id': fields.String, - 'workspace_icon': fields.String, - 'pages': fields.List(fields.Nested(integrate_page_fields)), - 'total': fields.Integer + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(integrate_page_fields)), + "total": fields.Integer, } integrate_fields = { - 'id': fields.String, - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'disabled': fields.Boolean, - 'link': fields.String, - 'source_info': fields.Nested(integrate_workspace_fields) + "id": fields.String, + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "disabled": fields.Boolean, + "link": fields.String, + "source_info": fields.Nested(integrate_workspace_fields), } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), -} \ No newline at end of file + "data": fields.List(fields.Nested(integrate_fields)), +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a9f79b5c67..9cf8da7acd 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -3,73 +3,64 @@ from flask_restful import fields from libs.helper import TimestampField dataset_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'permission': fields.String, - 'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, + "id": fields.String, + "name": fields.String, + "description": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "created_by": 
fields.String, + "created_at": TimestampField, } -reranking_model_fields = { - 'reranking_provider_name': fields.String, - 'reranking_model_name': fields.String -} +reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} -keyword_setting_fields = { - 'keyword_weight': fields.Float -} +keyword_setting_fields = {"keyword_weight": fields.Float} vector_setting_fields = { - 'vector_weight': fields.Float, - 'embedding_model_name': fields.String, - 'embedding_provider_name': fields.String, + "vector_weight": fields.Float, + "embedding_model_name": fields.String, + "embedding_provider_name": fields.String, } weighted_score_fields = { - 'keyword_setting': fields.Nested(keyword_setting_fields), - 'vector_setting': fields.Nested(vector_setting_fields), + "keyword_setting": fields.Nested(keyword_setting_fields), + "vector_setting": fields.Nested(vector_setting_fields), } dataset_retrieval_model_fields = { - 'search_method': fields.String, - 'reranking_enable': fields.Boolean, - 'reranking_mode': fields.String, - 'reranking_model': fields.Nested(reranking_model_fields), - 'weights': fields.Nested(weighted_score_fields, allow_null=True), - 'top_k': fields.Integer, - 'score_threshold_enabled': fields.Boolean, - 'score_threshold': fields.Float + "search_method": fields.String, + "reranking_enable": fields.Boolean, + "reranking_mode": fields.String, + "reranking_model": fields.Nested(reranking_model_fields), + "weights": fields.Nested(weighted_score_fields, allow_null=True), + "top_k": fields.Integer, + "score_threshold_enabled": fields.Boolean, + "score_threshold": fields.Float, } -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String -} +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} dataset_detail_fields = { - 'id': fields.String, - 'name': fields.String, - 'description': fields.String, - 'provider': fields.String, - 'permission': fields.String, - 
'data_source_type': fields.String, - 'indexing_technique': fields.String, - 'app_count': fields.Integer, - 'document_count': fields.Integer, - 'word_count': fields.Integer, - 'created_by': fields.String, - 'created_at': TimestampField, - 'updated_by': fields.String, - 'updated_at': TimestampField, - 'embedding_model': fields.String, - 'embedding_model_provider': fields.String, - 'embedding_available': fields.Boolean, - 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields), - 'tags': fields.List(fields.Nested(tag_fields)) + "id": fields.String, + "name": fields.String, + "description": fields.String, + "provider": fields.String, + "permission": fields.String, + "data_source_type": fields.String, + "indexing_technique": fields.String, + "app_count": fields.Integer, + "document_count": fields.Integer, + "word_count": fields.Integer, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, + "embedding_model": fields.String, + "embedding_model_provider": fields.String, + "embedding_available": fields.Boolean, + "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), + "tags": fields.List(fields.Nested(tag_fields)), } dataset_query_detail_fields = { @@ -79,7 +70,5 @@ dataset_query_detail_fields = { "source_app_id": fields.String, "created_by_role": fields.String, "created_by": fields.String, - "created_at": TimestampField + "created_at": TimestampField, } - - diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index e8215255b3..a83ec7bc97 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -4,75 +4,73 @@ from fields.dataset_fields import dataset_fields from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': 
fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'doc_form': fields.String, + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "doc_form": fields.String, } document_with_segments_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'data_source_type': fields.String, - 'data_source_info': fields.Raw(attribute='data_source_info_dict'), - 'data_source_detail_dict': fields.Raw(attribute='data_source_detail_dict'), - 'dataset_process_rule_id': fields.String, - 'name': fields.String, - 'created_from': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'tokens': fields.Integer, - 'indexing_status': fields.String, - 'error': fields.String, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'archived': fields.Boolean, - 'display_status': fields.String, - 
'word_count': fields.Integer, - 'hit_count': fields.Integer, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer + "id": fields.String, + "position": fields.Integer, + "data_source_type": fields.String, + "data_source_info": fields.Raw(attribute="data_source_info_dict"), + "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), + "dataset_process_rule_id": fields.String, + "name": fields.String, + "created_from": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "tokens": fields.Integer, + "indexing_status": fields.String, + "error": fields.String, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "archived": fields.Boolean, + "display_status": fields.String, + "word_count": fields.Integer, + "hit_count": fields.Integer, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } dataset_and_document_fields = { - 'dataset': fields.Nested(dataset_fields), - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String + "dataset": fields.Nested(dataset_fields), + "documents": fields.List(fields.Nested(document_fields)), + "batch": fields.String, } document_status_fields = { - 'id': fields.String, - 'indexing_status': fields.String, - 'processing_started_at': TimestampField, - 'parsing_completed_at': TimestampField, - 'cleaning_completed_at': TimestampField, - 'splitting_completed_at': TimestampField, - 'completed_at': TimestampField, - 'paused_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'completed_segments': fields.Integer, - 'total_segments': fields.Integer, + "id": fields.String, + "indexing_status": fields.String, + "processing_started_at": TimestampField, + "parsing_completed_at": TimestampField, + "cleaning_completed_at": TimestampField, + "splitting_completed_at": TimestampField, + "completed_at": TimestampField, + "paused_at": TimestampField, + "error": fields.String, 
+ "stopped_at": TimestampField, + "completed_segments": fields.Integer, + "total_segments": fields.Integer, } -document_status_fields_list = { - 'data': fields.List(fields.Nested(document_status_fields)) -} \ No newline at end of file +document_status_fields_list = {"data": fields.List(fields.Nested(document_status_fields))} diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index ee630c12c2..99e529f9d1 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,8 +1,8 @@ from flask_restful import fields simple_end_user_fields = { - 'id': fields.String, - 'type': fields.String, - 'is_anonymous': fields.Boolean, - 'session_id': fields.String, + "id": fields.String, + "type": fields.String, + "is_anonymous": fields.Boolean, + "session_id": fields.String, } diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc..e5a03ce77e 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -3,17 +3,17 @@ from flask_restful import fields from libs.helper import TimestampField upload_config_fields = { - 'file_size_limit': fields.Integer, - 'batch_count_limit': fields.Integer, - 'image_file_size_limit': fields.Integer, + "file_size_limit": fields.Integer, + "batch_count_limit": fields.Integer, + "image_file_size_limit": fields.Integer, } file_fields = { - 'id': fields.String, - 'name': fields.String, - 'size': fields.Integer, - 'extension': fields.String, - 'mime_type': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, -} \ No newline at end of file + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378..f36e80f8d4 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -3,39 +3,39 @@ from 
flask_restful import fields from libs.helper import TimestampField document_fields = { - 'id': fields.String, - 'data_source_type': fields.String, - 'name': fields.String, - 'doc_type': fields.String, + "id": fields.String, + "data_source_type": fields.String, + "name": fields.String, + "doc_type": fields.String, } segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField, - 'document': fields.Nested(document_fields), + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, + "document": fields.Nested(document_fields), } hit_testing_record_fields = { - 'segment': fields.Nested(segment_fields), - 'score': fields.Float, - 'tsne_position': fields.Raw -} \ No newline at end of file + "segment": fields.Nested(segment_fields), + "score": fields.Float, + "tsne_position": fields.Raw, +} 
diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 35cc5a6475..9afc1b1a4a 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,25 +1,25 @@ from flask_restful import fields -from libs.helper import TimestampField +from libs.helper import AppIconUrlField, TimestampField app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, } installed_app_fields = { - 'id': fields.String, - 'app': fields.Nested(app_fields), - 'app_owner_tenant_id': fields.String, - 'is_pinned': fields.Boolean, - 'last_used_at': TimestampField, - 'editable': fields.Boolean, - 'uninstallable': fields.Boolean + "id": fields.String, + "app": fields.Nested(app_fields), + "app_owner_tenant_id": fields.String, + "is_pinned": fields.Boolean, + "last_used_at": TimestampField, + "editable": fields.Boolean, + "uninstallable": fields.Boolean, } -installed_app_list_fields = { - 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} \ No newline at end of file +installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index d061b59c34..1cf8e408d1 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -2,38 +2,32 @@ from flask_restful import fields from libs.helper import TimestampField -simple_account_fields = { - 'id': fields.String, - 'name': fields.String, - 'email': fields.String -} +simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} account_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 
'is_password_set': fields.Boolean, - 'interface_language': fields.String, - 'interface_theme': fields.String, - 'timezone': fields.String, - 'last_login_at': TimestampField, - 'last_login_ip': fields.String, - 'created_at': TimestampField + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "is_password_set": fields.Boolean, + "interface_language": fields.String, + "interface_theme": fields.String, + "timezone": fields.String, + "last_login_at": TimestampField, + "last_login_ip": fields.String, + "created_at": TimestampField, } account_with_role_fields = { - 'id': fields.String, - 'name': fields.String, - 'avatar': fields.String, - 'email': fields.String, - 'last_login_at': TimestampField, - 'last_active_at': TimestampField, - 'created_at': TimestampField, - 'role': fields.String, - 'status': fields.String, + "id": fields.String, + "name": fields.String, + "avatar": fields.String, + "email": fields.String, + "last_login_at": TimestampField, + "last_active_at": TimestampField, + "created_at": TimestampField, + "role": fields.String, + "status": fields.String, } -account_with_role_list_fields = { - 'accounts': fields.List(fields.Nested(account_with_role_fields)) -} +account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 3116843589..3d2df87afb 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,83 +3,79 @@ from flask_restful import fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, 
- 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'files': fields.List(fields.String) + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "files": fields.List(fields.String), } retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': 
fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + 
"has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index e41d1a53dd..2dd4cb45be 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -3,31 +3,31 @@ from flask_restful import fields from libs.helper import TimestampField segment_fields = { - 'id': fields.String, - 'position': fields.Integer, - 'document_id': fields.String, - 'content': fields.String, - 'answer': fields.String, - 'word_count': fields.Integer, - 'tokens': fields.Integer, - 'keywords': fields.List(fields.String), - 'index_node_id': fields.String, - 'index_node_hash': fields.String, - 'hit_count': fields.Integer, - 'enabled': fields.Boolean, - 'disabled_at': TimestampField, - 'disabled_by': fields.String, - 'status': fields.String, - 'created_by': fields.String, - 'created_at': TimestampField, - 'indexing_at': TimestampField, - 'completed_at': TimestampField, - 'error': fields.String, - 'stopped_at': TimestampField + "id": fields.String, + "position": fields.Integer, + "document_id": fields.String, + "content": fields.String, + "answer": fields.String, + "word_count": fields.Integer, + "tokens": fields.Integer, + "keywords": fields.List(fields.String), + "index_node_id": fields.String, + "index_node_hash": fields.String, + "hit_count": fields.Integer, + "enabled": fields.Boolean, + "disabled_at": TimestampField, + "disabled_by": fields.String, + "status": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "indexing_at": TimestampField, + "completed_at": TimestampField, + "error": fields.String, + "stopped_at": TimestampField, } segment_list_response = { - 'data': fields.List(fields.Nested(segment_fields)), - 'has_more': fields.Boolean, - 'limit': fields.Integer + "data": fields.List(fields.Nested(segment_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, } diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 
f7e030b738..9af4fc57dd 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,8 +1,3 @@ from flask_restful import fields -tag_fields = { - 'id': fields.String, - 'name': fields.String, - 'type': fields.String, - 'binding_count': fields.String -} \ No newline at end of file +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index e230c159fb..a53b546249 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -7,18 +7,18 @@ from libs.helper import TimestampField workflow_app_log_partial_fields = { "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute='workflow_run', allow_null=True), + "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), "created_from": fields.String, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "created_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, } workflow_app_log_pagination_fields = { - 'page': fields.Integer, - 'limit': fields.Integer(attribute='per_page'), - 'total': fields.Integer, - 'has_more': fields.Boolean(attribute='has_next'), - 'data': fields.List(fields.Nested(workflow_app_log_partial_fields), attribute='items') + "page": fields.Integer, + "limit": fields.Integer(attribute="per_page"), + "total": fields.Integer, + "has_more": fields.Boolean(attribute="has_next"), + "data": 
fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"), } diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index c1dd0e184a..240b8f2eb0 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -13,43 +13,43 @@ class EnvironmentVariableField(fields.Raw): # Mask secret variables values in environment_variables if isinstance(value, SecretVariable): return { - 'id': value.id, - 'name': value.name, - 'value': encrypter.obfuscated_token(value.value), - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": encrypter.obfuscated_token(value.value), + "value_type": value.value_type.value, } if isinstance(value, Variable): return { - 'id': value.id, - 'name': value.name, - 'value': value.value, - 'value_type': value.value_type.value, + "id": value.id, + "name": value.name, + "value": value.value, + "value_type": value.value_type.value, } if isinstance(value, dict): - value_type = value.get('value_type') + value_type = value.get("value_type") if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: - raise ValueError(f'Unsupported environment variable value type: {value_type}') + raise ValueError(f"Unsupported environment variable value type: {value_type}") return value conversation_variable_fields = { - 'id': fields.String, - 'name': fields.String, - 'value_type': fields.String(attribute='value_type.value'), - 'value': fields.Raw, - 'description': fields.String, + "id": fields.String, + "name": fields.String, + "value_type": fields.String(attribute="value_type.value"), + "value": fields.Raw, + "description": fields.String, } workflow_fields = { - 'id': fields.String, - 'graph': fields.Raw(attribute='graph_dict'), - 'features': fields.Raw(attribute='features_dict'), - 'hash': fields.String(attribute='unique_hash'), - 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), - 'created_at': TimestampField, - 'updated_by': 
fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), - 'updated_at': TimestampField, - 'tool_published': fields.Boolean, - 'environment_variables': fields.List(EnvironmentVariableField()), - 'conversation_variables': fields.List(fields.Nested(conversation_variable_fields)), + "id": fields.String, + "graph": fields.Raw(attribute="graph_dict"), + "features": fields.Raw(attribute="features_dict"), + "hash": fields.String(attribute="unique_hash"), + "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), + "created_at": TimestampField, + "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), + "updated_at": TimestampField, + "tool_published": fields.Boolean, + "environment_variables": fields.List(EnvironmentVariableField()), + "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)), } diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 3e798473cd..1413adf719 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -13,7 +13,7 @@ workflow_run_for_log_fields = { "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_for_list_fields = { @@ -24,9 +24,9 @@ workflow_run_for_list_fields = { "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_for_list_fields = { @@ -39,40 +39,40 @@ advanced_chat_workflow_run_for_list_fields = { "elapsed_time": fields.Float, "total_tokens": 
fields.Integer, "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } advanced_chat_workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), } workflow_run_pagination_fields = { - 'limit': fields.Integer(attribute='limit'), - 'has_more': fields.Boolean(attribute='has_more'), - 'data': fields.List(fields.Nested(workflow_run_for_list_fields), attribute='data') + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), + "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), } workflow_run_detail_fields = { "id": fields.String, "sequence_number": fields.Integer, "version": fields.String, - "graph": fields.Raw(attribute='graph_dict'), - "inputs": fields.Raw(attribute='inputs_dict'), + "graph": fields.Raw(attribute="graph_dict"), + "inputs": fields.Raw(attribute="inputs_dict"), "status": fields.String, - "outputs": fields.Raw(attribute='outputs_dict'), + "outputs": fields.Raw(attribute="outputs_dict"), "error": fields.String, "elapsed_time": fields.Float, "total_tokens": fields.Integer, "total_steps": fields.Integer, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, 
attribute='created_by_end_user', allow_null=True), + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, - "finished_at": TimestampField + "finished_at": TimestampField, } workflow_run_node_execution_fields = { @@ -82,21 +82,21 @@ workflow_run_node_execution_fields = { "node_id": fields.String, "node_type": fields.String, "title": fields.String, - "inputs": fields.Raw(attribute='inputs_dict'), - "process_data": fields.Raw(attribute='process_data_dict'), - "outputs": fields.Raw(attribute='outputs_dict'), + "inputs": fields.Raw(attribute="inputs_dict"), + "process_data": fields.Raw(attribute="process_data_dict"), + "outputs": fields.Raw(attribute="outputs_dict"), "status": fields.String, "error": fields.String, "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute='execution_metadata_dict'), + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), "extras": fields.Raw, "created_at": TimestampField, "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute='created_by_end_user', allow_null=True), - "finished_at": TimestampField + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "finished_at": TimestampField, } workflow_run_node_execution_list_fields = { - 'data': fields.List(fields.Nested(workflow_run_node_execution_fields)), + "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py deleted file mode 100644 
index c1aee7b819..0000000000 --- a/api/libs/bearer_data_source.py +++ /dev/null @@ -1,64 +0,0 @@ -# [REVIEW] Implement if Needed? Do we need a new type of data source -from abc import abstractmethod - -import requests -from flask_login import current_user - -from extensions.ext_database import db -from models.source import DataSourceBearerBinding - - -class BearerDataSource: - def __init__(self, api_key: str, api_base_url: str): - self.api_key = api_key - self.api_base_url = api_base_url - - @abstractmethod - def validate_bearer_data_source(self): - """ - Validate the data source - """ - - -class FireCrawlDataSource(BearerDataSource): - def validate_bearer_data_source(self): - TEST_CRAWL_SITE_URL = "https://www.google.com" - FIRECRAWL_API_VERSION = "v0" - - test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - data = { - "url": TEST_CRAWL_SITE_URL, - } - - response = requests.get(test_api_endpoint, headers=headers, json=data) - - return response.json().get("status") == "success" - - def save_credentials(self): - # save data source binding - data_source_binding = DataSourceBearerBinding.query.filter( - db.and_( - DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, - DataSourceBearerBinding.provider == 'firecrawl', - DataSourceBearerBinding.endpoint_url == self.api_base_url, - DataSourceBearerBinding.bearer_key == self.api_key - ) - ).first() - if data_source_binding: - data_source_binding.disabled = False - db.session.commit() - else: - new_data_source_binding = DataSourceBearerBinding( - tenant_id=current_user.current_tenant_id, - provider='firecrawl', - endpoint_url=self.api_base_url, - bearer_key=self.api_key - ) - db.session.add(new_data_source_binding) - db.session.commit() diff --git a/api/libs/exception.py b/api/libs/exception.py index 567062f064..5970269ecd 100644 --- a/api/libs/exception.py +++ 
b/api/libs/exception.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException class BaseHTTPException(HTTPException): - error_code: str = 'unknown' + error_code: str = "unknown" data: Optional[dict] = None def __init__(self, description=None, response=None): @@ -14,4 +14,4 @@ class BaseHTTPException(HTTPException): "code": self.error_code, "message": self.description, "status": self.code, - } \ No newline at end of file + } diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 677ff0fc5b..179617ac0a 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -10,7 +10,6 @@ from core.errors.error import AppInvokeQuotaExceededError class ExternalApi(Api): - def handle_error(self, e): """Error handler for the API transforms a raised exception into a Flask response, with the appropriate HTTP status code and body. @@ -29,54 +28,57 @@ class ExternalApi(Api): status_code = e.code default_data = { - 'code': re.sub(r'(? self.max_length: - error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' - .format(arg=self.argument, val=value, length=self.max_length)) + error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( + arg=self.argument, val=value, length=self.max_length + ) raise ValueError(error) return value class float_range: - """ Restrict input to an float in a range (inclusive) """ - def __init__(self, low, high, argument='argument'): + """Restrict input to an float in a range (inclusive)""" + + def __init__(self, low, high, argument="argument"): self.low = low self.high = high self.argument = argument @@ -99,15 +113,16 @@ class float_range: def __call__(self, value): value = _get_float(value) if value < self.low or value > self.high: - error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' - .format(arg=self.argument, val=value, lo=self.low, hi=self.high)) + error = "Invalid {arg}: {val}. 
{arg} must be within the range {lo} - {hi}".format( + arg=self.argument, val=value, lo=self.low, hi=self.high + ) raise ValueError(error) return value class datetime_string: - def __init__(self, format, argument='argument'): + def __init__(self, format, argument="argument"): self.format = format self.argument = argument @@ -115,8 +130,9 @@ class datetime_string: try: datetime.strptime(value, self.format) except ValueError: - error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' - .format(arg=self.argument, val=value, format=self.format)) + error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( + arg=self.argument, val=value, format=self.format + ) raise ValueError(error) return value @@ -126,14 +142,14 @@ def _get_float(value): try: return float(value) except (TypeError, ValueError): - raise ValueError('{} is not a valid float'.format(value)) + raise ValueError("{} is not a valid float".format(value)) + def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string - error = ('{timezone_string} is not a valid timezone.' 
- .format(timezone_string=timezone_string)) + error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) raise ValueError(error) @@ -147,8 +163,8 @@ def generate_string(n): def get_remote_ip(request) -> str: - if request.headers.get('CF-Connecting-IP'): - return request.headers.get('Cf-Connecting-Ip') + if request.headers.get("CF-Connecting-IP"): + return request.headers.get("Cf-Connecting-Ip") elif request.headers.getlist("X-Forwarded-For"): return request.headers.getlist("X-Forwarded-For")[0] else: @@ -156,54 +172,45 @@ def get_remote_ip(request) -> str: def generate_text_hash(text: str) -> str: - hash_text = str(text) + 'None' + hash_text = str(text) + "None" return sha256(hash_text.encode()).hexdigest() def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') + return Response(response=json.dumps(response), status=200, mimetype="application/json") else: + def generate() -> Generator: yield from response - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') + return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") class TokenManager: - @classmethod def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: old_token = cls._get_current_token_for_account(account.id, token_type) if old_token: if isinstance(old_token, bytes): - old_token = old_token.decode('utf-8') + old_token = old_token.decode("utf-8") cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = { - 'account_id': account.id, - 'email': account.email, - 'token_type': token_type - } + token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} if additional_data: token_data.update(additional_data) - expiry_hours = 
current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] + expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] token_key = cls._get_token_key(token, token_type) - redis_client.setex( - token_key, - expiry_hours * 60 * 60, - json.dumps(token_data) - ) + redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @classmethod def _get_token_key(cls, token: str, token_type: str) -> str: - return f'{token_type}:token:{token}' + return f"{token_type}:token:{token}" @classmethod def revoke_token(cls, token: str, token_type: str): @@ -233,7 +240,7 @@ class TokenManager: @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: - return f'{token_type}:account:{account_id}' + return f"{token_type}:account:{account_id}" class RateLimiter: @@ -250,7 +257,7 @@ class RateLimiter: current_time = int(time.time()) window_start_time = current_time - self.time_window - redis_client.zremrangebyscore(key, '-inf', window_start_time) + redis_client.zremrangebyscore(key, "-inf", window_start_time) attempts = redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: diff --git a/api/libs/infinite_scroll_pagination.py b/api/libs/infinite_scroll_pagination.py index a1cb7b78fc..133ccb1883 100644 --- a/api/libs/infinite_scroll_pagination.py +++ b/api/libs/infinite_scroll_pagination.py @@ -1,4 +1,3 @@ - class InfiniteScrollPagination: def __init__(self, data, limit, has_more): self.data = data diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 2cf023a399..41d6905899 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict: end_index = json_string.find("```", start_index + len("```json")) if start_index != -1 and end_index != -1: - extracted_content = json_string[start_index + 
len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) elif start_index != -1 and end_index == -1 and json_string.endswith("``"): end_index = json_string.find("``", start_index + len("```json")) - extracted_content = json_string[start_index + len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) @@ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: for key in expected_keys: if key not in json_obj: raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " - f"to be present, but got {json_obj}" + f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index 14085fe603..7f05eb8404 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -51,27 +51,29 @@ def login_required(func): @wraps(func) def decorated_view(*args, **kwargs): - auth_header = request.headers.get('Authorization') - admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False') - if admin_api_key_enable.lower() == 'true': + auth_header = request.headers.get("Authorization") + admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") + if admin_api_key_enable.lower() == "true": if auth_header: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. 
Expected \'Bearer \' format.') - admin_api_key = os.getenv('ADMIN_API_KEY') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + admin_api_key = os.getenv("ADMIN_API_KEY") if admin_api_key: - if os.getenv('ADMIN_API_KEY') == auth_token: - workspace_id = request.headers.get('X-WORKSPACE-ID') + if os.getenv("ADMIN_API_KEY") == auth_token: + workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == workspace_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role == 'owner') \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == workspace_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role == "owner") .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() diff --git a/api/libs/oauth.py b/api/libs/oauth.py index dacdee0bc1..d8ce1a1e66 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -35,31 +35,31 @@ class OAuth: class GitHubOAuth(OAuth): - _AUTH_URL = 'https://github.com/login/oauth/authorize' - _TOKEN_URL = 'https://github.com/login/oauth/access_token' - _USER_INFO_URL = 'https://api.github.com/user' - _EMAIL_INFO_URL = 'https://api.github.com/user/emails' + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _USER_INFO_URL = "https://api.github.com/user" + _EMAIL_INFO_URL = "https://api.github.com/user/emails" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'redirect_uri': self.redirect_uri, - 'scope': 'user:email' # Request only basic user information + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": "user:email", # Request only basic user information } return 
f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in GitHub OAuth: {response_json}") @@ -67,55 +67,51 @@ class GitHubOAuth(OAuth): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"token {token}"} + headers = {"Authorization": f"token {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = response.json() email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email['primary'] == True), None) + primary_email = next((email for email in email_info if email["primary"] == True), None) - return {**user_info, 'email': primary_email['email']} + return {**user_info, "email": primary_email["email"]} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get('email') + email = raw_info.get("email") if not email: email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo( - id=str(raw_info['id']), - name=raw_info['name'], - email=email - ) + return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) class GoogleOAuth(OAuth): - _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth' - _TOKEN_URL = 'https://oauth2.googleapis.com/token' - _USER_INFO_URL = 
'https://www.googleapis.com/oauth2/v3/userinfo' + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'scope': 'openid email' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid email", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Google OAuth: {response_json}") @@ -123,16 +119,10 @@ class GoogleOAuth(OAuth): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"Bearer {token}"} + headers = {"Authorization": f"Bearer {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo( - id=str(raw_info['sub']), - name=None, - email=raw_info['email'] - ) - - + return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 358858ceb1..6da1a6d39b 100644 
--- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -21,53 +21,49 @@ class OAuthDataSource: class NotionOAuth(OAuthDataSource): - _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' - _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + _AUTH_URL = "https://api.notion.com/v1/oauth/authorize" + _TOKEN_URL = "https://api.notion.com/v1/oauth/token" _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'owner': 'user' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "owner": "user", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): - data = { - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri - } - headers = {'Accept': 'application/json'} + data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} + headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Notion OAuth: {response_json}") - workspace_name = response_json.get('workspace_name') - workspace_icon = response_json.get('workspace_icon') - workspace_id = response_json.get('workspace_id') + workspace_name = response_json.get("workspace_name") + workspace_icon = response_json.get("workspace_icon") + workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 
'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -79,7 +75,7 @@ class NotionOAuth(OAuthDataSource): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() @@ -91,18 +87,18 @@ class NotionOAuth(OAuthDataSource): # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -114,7 +110,7 @@ class NotionOAuth(OAuthDataSource): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() 
@@ -124,9 +120,9 @@ class NotionOAuth(OAuthDataSource): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False + DataSourceOauthBinding.disabled == False, ) ).first() if data_source_binding: @@ -134,17 +130,17 @@ class NotionOAuth(OAuthDataSource): pages = self.get_authorized_pages(data_source_binding.access_token) source_info = data_source_binding.source_info new_source_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, - 'total': len(pages) + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, + "total": len(pages), } data_source_binding.source_info = new_source_info data_source_binding.disabled = False db.session.commit() else: - raise ValueError('Data source binding not found') + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str): pages = [] @@ -152,143 +148,121 @@ class NotionOAuth(OAuthDataSource): database_results = self.notion_database_search(access_token) # get page detail for page_result in page_results: - page_id = page_result['id'] - page_name = 'Untitled' - for key in page_result['properties']: - if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']: - title_list = page_result['properties'][key]['title'] - if len(title_list) > 0 and 'plain_text' in title_list[0]: - page_name = title_list[0]['plain_text'] - page_icon = page_result['icon'] + page_id = page_result["id"] + page_name = "Untitled" + for key in page_result["properties"]: + if "title" in page_result["properties"][key] and 
page_result["properties"][key]["title"]: + title_list = page_result["properties"][key]["title"] + if len(title_list) > 0 and "plain_text" in title_list[0]: + page_name = title_list[0]["plain_text"] + page_icon = page_result["icon"] if page_icon: - icon_type = page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': 'emoji', - 'emoji': page_icon[icon_type] - } + icon = {"type": "emoji", "emoji": page_icon[icon_type]} else: icon = None - parent = page_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = page_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'page' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "page", } pages.append(page) # get database detail for database_result in database_results: - page_id = database_result['id'] - if len(database_result['title']) > 0: - page_name = database_result['title'][0]['plain_text'] + page_id = database_result["id"] + if len(database_result["title"]) > 0: + page_name = database_result["title"][0]["plain_text"] else: - page_name = 'Untitled' - page_icon = database_result['icon'] + page_name = "Untitled" + page_icon = database_result["icon"] if page_icon: - icon_type = 
page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': icon_type, - icon_type: page_icon[icon_type] - } + icon = {"type": icon_type, icon_type: page_icon[icon_type]} else: icon = None - parent = database_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = database_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'database' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "database", } pages.append(page) return pages def notion_page_search(self, access_token: str): - data = { - 'filter': { - "value": "page", - "property": "object" - } - } + data = {"filter": {"value": "page", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results def notion_block_parent_page_id(self, access_token: 
str, block_id: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } - response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers) + response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() - parent = response_json['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = response_json["parent"] + parent_type = parent["type"] + if parent_type == "block_id": return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] def notion_workspace_name(self, access_token: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() - if 'object' in response_json and response_json['object'] == 'user': - user_type = response_json['type'] + if "object" in response_json and response_json["object"] == "user": + user_type = response_json["type"] user_info = response_json[user_type] - if 'workspace_name' in user_info: - return user_info['workspace_name'] - return 'workspace' + if "workspace_name" in user_info: + return user_info["workspace_name"] + return "workspace" def notion_database_search(self, access_token: str): - data = { - 'filter': { - "value": "database", - "property": "object" - } - } + data = {"filter": {"value": "database", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) 
response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results diff --git a/api/libs/passport.py b/api/libs/passport.py index 34bdc55997..8df4f529bc 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -9,14 +9,14 @@ class PassportService: self.sk = dify_config.SECRET_KEY def issue(self, payload): - return jwt.encode(payload, self.sk, algorithm='HS256') + return jwt.encode(payload, self.sk, algorithm="HS256") def verify(self, token): try: - return jwt.decode(token, self.sk, algorithms=['HS256']) + return jwt.decode(token, self.sk, algorithms=["HS256"]) except jwt.exceptions.InvalidSignatureError: - raise Unauthorized('Invalid token signature.') + raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: - raise Unauthorized('Invalid token.') + raise Unauthorized("Invalid token.") except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized('Token has expired.') + raise Unauthorized("Token has expired.") diff --git a/api/libs/password.py b/api/libs/password.py index cdd1d69dbf..cfcc0db22d 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -5,6 +5,7 @@ import re password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" + def valid_password(password): # Define a regex pattern for password rules pattern = password_pattern @@ -12,11 +13,11 @@ def valid_password(password): if re.match(pattern, password) is not None: return password - raise ValueError('Not a valid password.') + raise ValueError("Not a valid password.") def hash_password(password_str, salt_byte): - dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000) + dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) return binascii.hexlify(dk) diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 9f29c58812..a578bf3e56 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -48,7 +48,7 @@ def encrypt(text, public_key): def 
get_decrypt_decoding(tenant_id): filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" - cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) private_key = redis_client.get(cache_key) if not private_key: try: @@ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id): def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): if encrypted_text.startswith(prefix_hybrid): - encrypted_text = encrypted_text[len(prefix_hybrid):] + encrypted_text = encrypted_text[len(prefix_hybrid) :] - enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()] - nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16] - tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32] - ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:] + enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()] + nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16] + tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32] + ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :] aes_key = cipher_rsa.decrypt(enc_aes_key) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index bf3a1a92e9..bd7de7dd68 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -5,7 +5,9 @@ from email.mime.text import MIMEText class SMTPClient: - def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False): + def __init__( + self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False + ): self.server = server self.port = port self._from = _from @@ -25,17 +27,17 @@ class SMTPClient: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: smtp = smtplib.SMTP(self.server, self.port, timeout=10) - + if self.username and self.password: 
smtp.login(self.username, self.password) msg = MIMEMultipart() - msg['Subject'] = mail['subject'] - msg['From'] = self._from - msg['To'] = mail['to'] - msg.attach(MIMEText(mail['html'], 'html')) + msg["Subject"] = mail["subject"] + msg["From"] = self._from + msg["To"] = mail["to"] + msg.attach(MIMEText(mail["html"], "html")) - smtp.sendmail(self._from, mail['to'], msg.as_string()) + smtp.sendmail(self._from, mail["to"], msg.as_string()) except smtplib.SMTPException as e: logging.error(f"SMTP error occurred: {str(e)}") raise diff --git a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py new file mode 100644 index 0000000000..d814666eef --- /dev/null +++ b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py @@ -0,0 +1,39 @@ +"""app and site icon type + +Revision ID: a6be81136580 +Revises: 8782057ff0dc +Create Date: 2024-08-15 10:01:24.697888 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'a6be81136580' +down_revision = '8782057ff0dc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py new file mode 100644 index 0000000000..3dc7fed818 --- /dev/null +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -0,0 +1,28 @@ +"""rename workflow__conversation_variables to workflow_conversation_variables + +Revision ID: 2dbe42621d96 +Revises: a6be81136580 +Create Date: 2024-08-20 04:55:38.160010 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2dbe42621d96' +down_revision = 'a6be81136580' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 5426d3bc83..7301629771 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -51,6 +51,10 @@ class AppMode(Enum): raise ValueError(f'invalid mode value {value}') +class IconType(Enum): + IMAGE = "image" + EMOJI = "emoji" + class App(db.Model): __tablename__ = 'apps' __table_args__ = ( @@ -63,6 +67,7 @@ class App(db.Model): name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -1087,6 +1092,7 @@ class Site(db.Model): id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) description = db.Column(db.Text) @@ -1316,9 +1322,7 @@ class MessageAgentThought(db.Model): } except Exception as e: if self.observation: - return { - tool: self.observation for tool in tools - } + return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): diff --git a/api/models/workflow.py b/api/models/workflow.py index 759e07c715..cdd5e1992d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,5 +1,6 @@ import json from collections.abc import Mapping, Sequence +from datetime import datetime from enum import Enum from typing import Any, Optional, Union @@ -110,19 +111,32 @@ class Workflow(db.Model): db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) - id = 
db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - features = db.Column(db.Text) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(StringUUID) - updated_at = db.Column(db.DateTime) - _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + app_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + version: Mapped[str] = db.Column(db.String(255), nullable=False) + graph: Mapped[str] = db.Column(db.Text) + features: Mapped[str] = db.Column(db.Text) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_by: Mapped[str] = db.Column(StringUUID) + updated_at: Mapped[datetime] = db.Column(db.DateTime) + _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') + _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + + def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, + features: str, created_by: str, environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable]): + self.tenant_id = tenant_id + 
self.app_id = app_id + self.type = type + self.version = version + self.graph = graph + self.features = features + self.created_by = created_by + self.environment_variables = environment_variables or [] + self.conversation_variables = conversation_variables or [] @property def created_by_account(self): @@ -724,7 +738,7 @@ class WorkflowAppLog(db.Model): class ConversationVariable(db.Model): - __tablename__ = 'workflow__conversation_variables' + __tablename__ = 'workflow_conversation_variables' id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) diff --git a/api/poetry.lock b/api/poetry.lock index 358f9f8510..9bfeec30d7 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -7370,29 +7370,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.5.7" +version = "0.6.1" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, - {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, - {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = 
"sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, - {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, - {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, - {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, - {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, + {file = "ruff-0.6.1-py3-none-linux_armv6l.whl", hash = "sha256:b4bb7de6a24169dc023f992718a9417380301b0c2da0fe85919f47264fb8add9"}, + {file = "ruff-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:45efaae53b360c81043e311cdec8a7696420b3d3e8935202c2846e7a97d4edae"}, + {file = "ruff-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = 
"sha256:bc60c7d71b732c8fa73cf995efc0c836a2fd8b9810e115be8babb24ae87e0850"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c7477c3b9da822e2db0b4e0b59e61b8a23e87886e727b327e7dcaf06213c5cf"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3a0af7ab3f86e3dc9f157a928e08e26c4b40707d0612b01cd577cc84b8905cc9"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:392688dbb50fecf1bf7126731c90c11a9df1c3a4cdc3f481b53e851da5634fa5"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5278d3e095ccc8c30430bcc9bc550f778790acc211865520f3041910a28d0024"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe6d5f65d6f276ee7a0fc50a0cecaccb362d30ef98a110f99cac1c7872df2f18"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e0dd11e2ae553ee5c92a81731d88a9883af8db7408db47fc81887c1f8b672e"}, + {file = "ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d812615525a34ecfc07fd93f906ef5b93656be01dfae9a819e31caa6cfe758a1"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faaa4060f4064c3b7aaaa27328080c932fa142786f8142aff095b42b6a2eb631"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:99d7ae0df47c62729d58765c593ea54c2546d5de213f2af2a19442d50a10cec9"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9eb18dfd7b613eec000e3738b3f0e4398bf0153cb80bfa3e351b3c1c2f6d7b15"}, + {file = "ruff-0.6.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c62bc04c6723a81e25e71715aa59489f15034d69bf641df88cb38bdc32fd1dbb"}, + {file = "ruff-0.6.1-py3-none-win32.whl", hash = "sha256:9fb4c4e8b83f19c9477a8745e56d2eeef07a7ff50b68a6998f7d9e2e3887bdc4"}, + {file = "ruff-0.6.1-py3-none-win_amd64.whl", hash = 
"sha256:c2ebfc8f51ef4aca05dad4552bbcf6fe8d1f75b2f6af546cc47cc1c1ca916b5b"}, + {file = "ruff-0.6.1-py3-none-win_arm64.whl", hash = "sha256:3bc81074971b0ffad1bd0c52284b22411f02a11a012082a76ac6da153536e014"}, + {file = "ruff-0.6.1.tar.gz", hash = "sha256:af3ffd8c6563acb8848d33cd19a69b9bfe943667f0419ca083f8ebe4224a3436"}, ] [[package]] @@ -9584,4 +9584,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf" +content-hash = "a74c7b6a72145d5074aa84581df6e543ea422810caf0ba1561cd2d35497243ca" diff --git a/api/pyproject.toml b/api/pyproject.toml index 82e2aaeb2b..82ccd0b202 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -69,7 +69,16 @@ ignore = [ ] [tool.ruff.format] -quote-style = "single" +exclude = [ + "core/**/*.py", + "controllers/**/*.py", + "models/**/*.py", + "migrations/**/*", + "services/**/*.py", + "tasks/**/*.py", + "tests/**/*.py", + "configs/**/*.py", +] [tool.pytest_env] OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" @@ -147,6 +156,7 @@ markdown = "~3.5.1" novita-client = "^0.5.6" numpy = "~1.26.4" openai = "~1.29.0" +openpyxl = "~3.1.5" oss2 = "2.18.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" @@ -164,7 +174,6 @@ readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" -safetensors = "~0.4.3" scikit-learn = "^1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" @@ -178,10 +187,16 @@ werkzeug = "~3.0.1" xinference-client = "0.13.3" yarl = "~1.9.4" zhipuai = "1.0.7" -rank-bm25 = "~0.2.2" -openpyxl = "^3.1.5" +# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. 
+ +############################################################ +# Related transparent dependencies with pinned verion +# required by main implementations +############################################################ +[tool.poetry.group.indriect.dependencies] kaleido = "0.2.1" -elasticsearch = "8.14.0" +rank-bm25 = "~0.2.2" +safetensors = "~0.4.3" ############################################################ # Tool dependencies required by tool implementations @@ -189,6 +204,7 @@ elasticsearch = "8.14.0" [tool.poetry.group.tool.dependencies] arxiv = "2.1.0" +cloudscraper = "1.2.71" matplotlib = "~3.8.2" newspaper3k = "0.2.8" duckduckgo-search = "^6.2.6" @@ -200,26 +216,25 @@ twilio = "~9.0.4" vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } wikipedia = "1.4.0" yfinance = "~0.2.40" -cloudscraper = "1.2.71" ############################################################ # VDB dependencies required by vector store clients ############################################################ [tool.poetry.group.vdb.dependencies] +alibabacloud_gpdb20160503 = "~3.8.0" +alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" +clickhouse-connect = "~0.7.16" +elasticsearch = "8.14.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" -pymysql = "1.1.1" tcvectordb = "1.3.2" tidb-vector = "0.0.9" qdrant-client = "1.7.3" weaviate-client = "~3.21.0" -alibabacloud_gpdb20160503 = "~3.8.0" -alibabacloud_tea_openapi = "~0.3.9" -clickhouse-connect = "~0.7.16" ############################################################ # Dev dependencies for running tests @@ -243,5 +258,5 @@ pytest-mock = "~3.14.0" optional = true [tool.poetry.group.lint.dependencies] -ruff = "~0.5.7" dotenv-linter = "~0.5.0" +ruff = "~0.6.1" diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index ccc1062266..67d0706828 100644 --- a/api/schedule/clean_embedding_cache_task.py 
+++ b/api/schedule/clean_embedding_cache_task.py @@ -11,27 +11,32 @@ from extensions.ext_database import db from models.dataset import Embedding -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def clean_embedding_cache_task(): - click.echo(click.style('Start clean embedding cache.', fg='green')) + click.echo(click.style("Start clean embedding cache.", fg="green")) clean_days = int(dify_config.CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = db.session.query(Embedding.id).filter(Embedding.created_at < thirty_days_ago) \ - .order_by(Embedding.created_at.desc()).limit(100).all() + embedding_ids = ( + db.session.query(Embedding.id) + .filter(Embedding.created_at < thirty_days_ago) + .order_by(Embedding.created_at.desc()) + .limit(100) + .all() + ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] except NotFound: break if embedding_ids: for embedding_id in embedding_ids: - db.session.execute(text( - "DELETE FROM embeddings WHERE id = :embedding_id" - ), {'embedding_id': embedding_id}) + db.session.execute( + text("DELETE FROM embeddings WHERE id = :embedding_id"), {"embedding_id": embedding_id} + ) db.session.commit() else: break end_at = time.perf_counter() - click.echo(click.style('Cleaned embedding cache from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index b2b2f82b78..3d799bfd4e 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -12,9 +12,9 @@ from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, Document -@app.celery.task(queue='dataset') +@app.celery.task(queue="dataset") def 
clean_unused_datasets_task(): - click.echo(click.style('Start clean unused datasets indexes.', fg='green')) + click.echo(click.style("Start clean unused datasets indexes.", fg="green")) clean_days = dify_config.CLEAN_DAY_SETTING start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) @@ -22,40 +22,44 @@ def clean_unused_datasets_task(): while True: try: # Subquery for counting new documents - document_subquery_new = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at > thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Subquery for counting old documents - document_subquery_old = db.session.query( - Document.dataset_id, - func.count(Document.id).label('document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.updated_at < thirty_days_ago - ).group_by(Document.dataset_id).subquery() + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < thirty_days_ago, + ) + .group_by(Document.dataset_id) + .subquery() + ) # Main query with join and filter - datasets = (db.session.query(Dataset) - .outerjoin( - document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id - ).outerjoin( - document_subquery_old, Dataset.id == 
document_subquery_old.c.dataset_id - ).filter( - Dataset.created_at < thirty_days_ago, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0 - ).order_by( - Dataset.created_at.desc() - ).paginate(page=page, per_page=50)) + datasets = ( + db.session.query(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < thirty_days_ago, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) except NotFound: break @@ -63,10 +67,11 @@ def clean_unused_datasets_task(): break page += 1 for dataset in datasets: - dataset_query = db.session.query(DatasetQuery).filter( - DatasetQuery.created_at > thirty_days_ago, - DatasetQuery.dataset_id == dataset.id - ).all() + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id) + .all() + ) if not dataset_query or len(dataset_query) == 0: try: # remove index @@ -74,17 +79,14 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # update document - update_params = { - Document.enabled: False - } + update_params = {Document.enabled: False} Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() - click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id), - fg='green')) + click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: click.echo( - click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), 
fg="red") + ) end_at = time.perf_counter() - click.echo(click.style('Cleaned unused dataset from db success latency: {}'.format(end_at - start_at), fg='green')) + click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 838585c575..737def3366 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -82,6 +82,7 @@ class AppDslService: # get app basic info name = args.get("name") if args.get("name") else app_data.get('name') description = args.get("description") if args.get("description") else app_data.get('description', '') + icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type') icon = args.get("icon") if args.get("icon") else app_data.get('icon') icon_background = args.get("icon_background") if args.get("icon_background") \ else app_data.get('icon_background') @@ -96,6 +97,7 @@ class AppDslService: account=account, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background ) @@ -107,6 +109,7 @@ class AppDslService: account=account, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background ) @@ -165,8 +168,8 @@ class AppDslService: "app": { "name": app_model.name, "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background, + "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon, + "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background, "description": app_model.description } } @@ -207,6 +210,7 @@ class AppDslService: account: Account, name: str, description: str, + icon_type: str, icon: str, icon_background: str) -> App: """ @@ -218,6 +222,7 @@ class AppDslService: :param account: Account instance :param name: app name :param description: app description + :param icon_type: app icon type, 
"emoji" or "image" :param icon: app icon :param icon_background: app icon background """ @@ -231,6 +236,7 @@ class AppDslService: account=account, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background ) @@ -285,6 +291,8 @@ class AppDslService: # sync draft workflow environment_variables_list = workflow_data.get('environment_variables') or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] + conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=workflow_data.get('graph', {}), @@ -292,6 +300,7 @@ class AppDslService: unique_hash=unique_hash, account=account, environment_variables=environment_variables, + conversation_variables=conversation_variables, ) return draft_workflow @@ -304,6 +313,7 @@ class AppDslService: account: Account, name: str, description: str, + icon_type: str, icon: str, icon_background: str) -> App: """ @@ -328,6 +338,7 @@ class AppDslService: account=account, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background ) @@ -355,6 +366,7 @@ class AppDslService: account: Account, name: str, description: str, + icon_type: str, icon: str, icon_background: str) -> App: """ @@ -365,6 +377,7 @@ class AppDslService: :param account: Account instance :param name: app name :param description: app description + :param icon_type: app icon type, "emoji" or "image" :param icon: app icon :param icon_background: app icon background """ @@ -373,6 +386,7 @@ class AppDslService: mode=app_mode.value, name=name, description=description, + icon_type=icon_type, icon=icon, icon_background=icon_background, enable_site=True, diff --git a/api/services/app_service.py b/api/services/app_service.py index 
e433bb59bb..93f7169c1f 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -111,6 +111,12 @@ class AppService: 'completion_params': {} } else: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id, + model_type=ModelType.LLM + ) + default_model_config['model']['provider'] = provider + default_model_config['model']['name'] = model default_model_dict = default_model_config['model'] default_model_config['model'] = json.dumps(default_model_dict) @@ -119,6 +125,7 @@ class AppService: app.name = args['name'] app.description = args.get('description', '') app.mode = args['mode'] + app.icon_type = args.get('icon_type', 'emoji') app.icon = args['icon'] app.icon_background = args['icon_background'] app.tenant_id = tenant_id @@ -189,13 +196,14 @@ class AppService: """ Modified App class """ + def __init__(self, app): self.__dict__.update(app.__dict__) @property def app_model_config(self): return model_config - + app = ModifiedApp(app) return app @@ -210,6 +218,7 @@ class AppService: app.name = args.get('name') app.description = args.get('description', '') app.max_active_requests = args.get('max_active_requests') + app.icon_type = args.get('icon_type', 'emoji') app.icon = args.get('icon') app.icon_background = args.get('icon_background') app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 82ee10ee78..053d704e17 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,7 @@ +from datetime import datetime, timezone from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import asc, desc, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,7 +19,8 @@ class ConversationService: last_id: Optional[str], limit: int, invoke_from: InvokeFrom, include_ids: 
Optional[list] = None, - exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + exclude_ids: Optional[list] = None, + sort_by: str = '-updated_at') -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -37,28 +39,28 @@ class ConversationService: if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) - if last_id: - last_conversation = base_query.filter( - Conversation.id == last_id, - ).first() + # define sort fields and directions + sort_field, sort_direction = cls._get_sort_params(sort_by) + if last_id: + last_conversation = base_query.filter(Conversation.id == last_id).first() if not last_conversation: raise LastConversationNotExistsError() - conversations = base_query.filter( - Conversation.created_at < last_conversation.created_at, - Conversation.id != last_conversation.id - ).order_by(Conversation.created_at.desc()).limit(limit).all() - else: - conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() + # build filters based on sorting + filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) + base_query = base_query.filter(filter_condition) + + base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) + + conversations = base_query.limit(limit).all() has_more = False if len(conversations) == limit: - current_page_first_conversation = conversations[-1] - rest_count = base_query.filter( - Conversation.created_at < current_page_first_conversation.created_at, - Conversation.id != current_page_first_conversation.id - ).count() + current_page_last_conversation = conversations[-1] + rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction, + current_page_last_conversation, is_next_page=True) + rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True @@ -69,6 +71,21 @@ class ConversationService: 
has_more=has_more ) + @classmethod + def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: + if sort_by.startswith('-'): + return sort_by[1:], desc + return sort_by, asc + + @classmethod + def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, + is_next_page: bool = False): + field_value = getattr(reference_conversation, sort_field) + if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + return getattr(Conversation, sort_field) < field_value + else: + return getattr(Conversation, sort_field) > field_value + @classmethod def rename(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): @@ -78,6 +95,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return conversation @@ -87,9 +105,9 @@ class ConversationService: # get conversation first message message = db.session.query(Message) \ .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + Message.app_id == app_model.id, + Message.conversation_id == conversation.id + ).order_by(Message.created_at.asc()).first() if not message: raise MessageNotExistsError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 9052a0b785..12ae0e39a8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1429,7 +1429,10 @@ class SegmentService: segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - keywords_list.append(segment_item['keywords']) + if 'keywords' in segment_item: + keywords_list.append(segment_item['keywords']) + else: + keywords_list.append(None) try: # save vector index @@ -1482,7 +1485,7 @@ class SegmentService: 
db.session.add(segment) db.session.commit() # update segment index task - if args['keywords']: + if 'keywords' in args: keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) document = RAGDocument( diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 385af685f9..86c059fe90 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -30,6 +30,7 @@ class ModelProviderService: """ Model Provider Service """ + def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -387,18 +388,21 @@ class ModelProviderService: tenant_id=tenant_id, model_type=model_type_enum ) - - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types - ) - ) if result else None + try: + return DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types + ) + ) if result else None + except Exception as e: + logger.info(f"get_default_model_of_model_type error: {e}") + return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ea6ecf0c69..ebadbd9be0 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,8 @@ import json import logging +from configs import dify_config +from 
core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError @@ -43,14 +45,14 @@ class BuiltinToolManageService: result = [] for tool in tools: result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, + tool=tool, + credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller) )) return result - + @staticmethod def list_builtin_provider_credentials_schema( provider_name @@ -78,7 +80,7 @@ class BuiltinToolManageService: BuiltinToolProvider.provider == provider_name, ).first() - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: @@ -119,8 +121,8 @@ class BuiltinToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {'result': 'success'} + @staticmethod def get_builtin_tool_provider_credentials( user_id: str, tenant_id: str, provider: str @@ -135,7 +137,7 @@ class BuiltinToolManageService: if provider is None: return {} - + provider_controller = ToolManager.get_builtin_provider(provider.provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -156,7 +158,7 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f'you have not added provider {provider_name}') - + db.session.delete(provider) db.session.commit() @@ -165,8 +167,8 @@ class BuiltinToolManageService: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 
'success' } - + return {'result': 'success'} + @staticmethod def get_builtin_tool_provider_icon( provider: str @@ -179,7 +181,7 @@ class BuiltinToolManageService: icon_bytes = f.read() return icon_bytes, mime_type - + @staticmethod def list_builtin_tools( user_id: str, tenant_id: str @@ -202,6 +204,15 @@ class BuiltinToolManageService: for provider_controller in provider_controllers: try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name + ): + continue + # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -226,4 +237,3 @@ class BuiltinToolManageService: raise e return BuiltinToolProviderSort.sort(result) - \ No newline at end of file diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index e89d94160c..185483a71c 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -32,6 +32,7 @@ class WorkflowToolManageService: :param description: the description :param parameters: the parameters :param privacy_policy: the privacy policy + :param labels: labels :return: the created tool """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) @@ -92,7 +93,14 @@ class WorkflowToolManageService: Update a workflow tool. 
:param user_id: the user id :param tenant_id: the tenant id - :param tool: the tool + :param workflow_tool_id: workflow tool id + :param name: name + :param label: label + :param icon: icon + :param description: description + :param parameters: parameters + :param privacy_policy: privacy policy + :param labels: labels :return: the updated tool """ WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index cba048ccdb..269a048134 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -13,7 +13,8 @@ class WebConversationService: @classmethod def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None) -> InfiniteScrollPagination: + pinned: Optional[bool] = None, + sort_by='-updated_at') -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: @@ -36,6 +37,7 @@ class WebConversationService: invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, + sort_by=sort_by ) @classmethod diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index f993608293..c9057bd0e5 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -32,11 +32,9 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, - account: Account, - name: str, - icon: str, - icon_background: str) -> App: + def convert_to_workflow( + self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str + ): """ Convert app to workflow @@ -50,22 +48,24 @@ class WorkflowConverter: :param account: Account :param name: new app name :param icon: new app icon + :param icon_type: new app icon type :param 
icon_background: new app icon background :return: new App instance """ # convert app model config + if not app_model.app_model_config: + raise ValueError("App model config is required") + workflow = self.convert_app_model_config_to_workflow( - app_model=app_model, - app_model_config=app_model.app_model_config, - account_id=account.id + app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id ) # create new app new_app = App() new_app.tenant_id = app_model.tenant_id - new_app.name = name if name else app_model.name + '(workflow)' - new_app.mode = AppMode.ADVANCED_CHAT.value \ - if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.name = name if name else app_model.name + "(workflow)" + new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon_type = icon_type if icon_type else app_model.icon_type new_app.icon = icon if icon else app_model.icon new_app.icon_background = icon_background if icon_background else app_model.icon_background new_app.enable_site = app_model.enable_site @@ -85,30 +85,21 @@ class WorkflowConverter: return new_app - def convert_app_model_config_to_workflow(self, app_model: App, - app_model_config: AppModelConfig, - account_id: str) -> Workflow: + def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): """ Convert app model config to workflow mode :param app_model: App instance :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) # convert app model config - app_config = self._convert_to_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = { - "nodes": [], - "edges": [] - } + graph = {"nodes": [], "edges": 
[]} # Convert list: # - variables -> start @@ -120,11 +111,9 @@ class WorkflowConverter: # - show_retrieve_source -> knowledge-retrieval # convert to start node - start_node = self._convert_to_start_node( - variables=app_config.variables - ) + start_node = self._convert_to_start_node(variables=app_config.variables) - graph['nodes'].append(start_node) + graph["nodes"].append(start_node) # convert to http request node external_data_variable_node_mapping = {} @@ -132,7 +121,7 @@ class WorkflowConverter: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, variables=app_config.variables, - external_data_variables=app_config.external_data_variables + external_data_variables=app_config.external_data_variables, ) for http_request_node in http_request_nodes: @@ -141,9 +130,7 @@ class WorkflowConverter: # convert to knowledge retrieval node if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=app_config.dataset, - model_config=app_config.model + new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model ) if knowledge_retrieval_node: @@ -157,7 +144,7 @@ class WorkflowConverter: model_config=app_config.model, prompt_template=app_config.prompt_template, file_upload=app_config.additional_features.file_upload, - external_data_variable_node_mapping=external_data_variable_node_mapping + external_data_variable_node_mapping=external_data_variable_node_mapping, ) graph = self._append_node(graph, llm_node) @@ -196,11 +183,12 @@ class WorkflowConverter: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(new_app_mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account_id, environment_variables=[], + conversation_variables=[], ) db.session.add(workflow) @@ -208,24 +196,18 @@ class WorkflowConverter: return 
workflow - def _convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode = AppMode.value_of(app_model.mode) if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) elif app_mode == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) else: raise ValueError("Invalid app mode") @@ -244,14 +226,13 @@ class WorkflowConverter: "data": { "title": "START", "type": NodeType.START.value, - "variables": [jsonable_encoder(v) for v in variables] - } + "variables": [jsonable_encoder(v) for v in variables], + }, } - def _convert_to_http_request_node(self, app_model: App, - variables: list[VariableEntity], - external_data_variables: list[ExternalDataVariableEntity]) \ - -> tuple[list[dict], dict[str, str]]: + def _convert_to_http_request_node( + self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity] + ) -> tuple[list[dict], dict[str, str]]: """ Convert API Based Extension to HTTP Request Node :param app_model: App instance @@ -273,40 +254,33 @@ class WorkflowConverter: # get params from config api_based_extension_id = tool_config.get("api_based_extension_id") + if not api_based_extension_id: + continue # get api_based_extension api_based_extension = 
self._get_api_based_extension( - tenant_id=tenant_id, - api_based_extension_id=api_based_extension_id + tenant_id=tenant_id, api_based_extension_id=api_based_extension_id ) - if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(tool_variable)) - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) inputs = {} for v in variables: - inputs[v.variable] = '{{#start.' + v.variable + '#}}' + inputs[v.variable] = "{{#start." + v.variable + "#}}" request_body = { - 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, - 'params': { - 'app_id': app_model.id, - 'tool_variable': tool_variable, - 'inputs': inputs, - 'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else '' - } + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "params": { + "app_id": app_model.id, + "tool_variable": tool_variable, + "inputs": inputs, + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + }, } request_body_json = json.dumps(request_body) - request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}') + request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") http_request_node = { "id": f"http_request_{index}", @@ -316,20 +290,11 @@ class WorkflowConverter: "type": NodeType.HTTP_REQUEST.value, "method": "post", "url": api_based_extension.api_endpoint, - "authorization": { - "type": "api-key", - "config": { - "type": "bearer", - "api_key": api_key - } - }, + "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, "headers": "", "params": "", - "body": { - "type": "json", - "data": request_body_json - } - } + "body": {"type": "json", "data": request_body_json}, + }, } 
nodes.append(http_request_node) @@ -341,32 +306,24 @@ class WorkflowConverter: "data": { "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, - "variables": [{ - "variable": "response_json", - "value_selector": [http_request_node['id'], "body"] - }], + "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" - "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", - "outputs": { - "result": { - "type": "string" - } - } - } + 'response_json)\n return {\n "result": response_body["result"]\n }', + "outputs": {"result": {"type": "string"}}, + }, } nodes.append(code_node) - external_data_variable_node_mapping[external_data_variable.variable] = code_node['id'] + external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] index += 1 return nodes, external_data_variable_node_mapping - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, - dataset_config: DatasetEntity, - model_config: ModelConfigEntity) \ - -> Optional[dict]: + def _convert_to_knowledge_retrieval_node( + self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity + ) -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -400,7 +357,7 @@ class WorkflowConverter: "completion_params": { **model_config.parameters, "stop": model_config.stop, - } + }, } } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE @@ -408,20 +365,23 @@ class WorkflowConverter: "multiple_retrieval_config": { "top_k": retrieve_config.top_k, "score_threshold": retrieve_config.score_threshold, - "reranking_model": retrieve_config.reranking_model + "reranking_model": retrieve_config.reranking_model, } if retrieve_config.retrieve_strategy == 
DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE else None, - } + }, } - def _convert_to_llm_node(self, original_app_mode: AppMode, - new_app_mode: AppMode, - graph: dict, - model_config: ModelConfigEntity, - prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, - external_data_variable_node_mapping: dict[str, str] = None) -> dict: + def _convert_to_llm_node( + self, + original_app_mode: AppMode, + new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileExtraConfig] = None, + external_data_variable_node_mapping: dict[str, str] | None = None, + ) -> dict: """ Convert to LLM Node :param original_app_mode: original app mode @@ -433,17 +393,18 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) - knowledge_retrieval_node = next(filter( - lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, - graph['nodes'] - ), None) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + knowledge_retrieval_node = next( + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + ) role_prefix = None # Chat Model if model_config.mode == LLMMode.CHAT.value: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -452,45 +413,35 @@ class WorkflowConverter: model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + 
query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template if not template: prompts = [] else: template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts = [ - { - "role": 'user', - "text": template - } - ] + prompts = [{"role": "user", "text": template}] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template prompts = [] - for m in advanced_chat_prompt_template.messages: - if advanced_chat_prompt_template: + if advanced_chat_prompt_template: + for m in advanced_chat_prompt_template.messages: text = m.text text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + text, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts.append({ - "role": m.role.value, - "text": text - }) + prompts.append({"role": m.role.value, "text": text}) # Completion Model else: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -499,57 +450,50 @@ class WorkflowConverter: model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template=template, + variables=start_node["data"]["variables"], 
+ external_data_variable_node_mapping=external_data_variable_node_mapping, ) - prompts = { - "text": template - } + prompts = {"text": template} - prompt_rules = prompt_template_config['prompt_rules'] + prompt_rules = prompt_template_config["prompt_rules"] role_prefix = { - "user": prompt_rules.get('human_prefix', 'Human'), - "assistant": prompt_rules.get('assistant_prefix', 'Assistant') + "user": prompt_rules.get("human_prefix", "Human"), + "assistant": prompt_rules.get("assistant_prefix", "Assistant"), } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template if advanced_completion_prompt_template: text = advanced_completion_prompt_template.prompt text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + template=text, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) else: text = "" - text = text.replace('{{#query#}}', '{{#sys.query#}}') + text = text.replace("{{#query#}}", "{{#sys.query#}}") prompts = { "text": text, } - if advanced_completion_prompt_template.role_prefix: + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: role_prefix = { "user": advanced_completion_prompt_template.role_prefix.user, - "assistant": advanced_completion_prompt_template.role_prefix.assistant + "assistant": advanced_completion_prompt_template.role_prefix.assistant, } memory = None if new_app_mode == AppMode.ADVANCED_CHAT: - memory = { - "role_prefix": role_prefix, - "window": { - "enabled": False - } - } + memory = {"role_prefix": role_prefix, "window": {"enabled": False}} completion_params = model_config.parameters completion_params.update({"stop": model_config.stop}) @@ -563,41 +507,42 @@ class WorkflowConverter: "provider": model_config.provider, "name": model_config.model, "mode": model_config.mode, - "completion_params": completion_params + "completion_params": 
completion_params, }, "prompt_template": prompts, "memory": memory, "context": { "enabled": knowledge_retrieval_node is not None, "variable_selector": ["knowledge_retrieval", "result"] - if knowledge_retrieval_node is not None else None + if knowledge_retrieval_node is not None + else None, }, "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": { - "detail": file_upload.image_config['detail'] - } if file_upload is not None else None - } - } + "configs": {"detail": file_upload.image_config["detail"]} + if file_upload is not None and file_upload.image_config is not None + else None, + }, + }, } - def _replace_template_variables(self, template: str, - variables: list[dict], - external_data_variable_node_mapping: dict[str, str] = None) -> str: + def _replace_template_variables( + self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None + ) -> str: """ Replace Template Variables :param template: template :param variables: list of variables + :param external_data_variable_node_mapping: external data variable node mapping :return: """ for v in variables: - template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + template = template.replace("{{" + v["variable"] + "}}", "{{#start." 
+ v["variable"] + "#}}") if external_data_variable_node_mapping: for variable, code_node_id in external_data_variable_node_mapping.items(): - template = template.replace('{{' + variable + '}}', - '{{#' + code_node_id + '.result#}}') + template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") return template @@ -613,11 +558,8 @@ class WorkflowConverter: "data": { "title": "END", "type": NodeType.END.value, - "outputs": [{ - "variable": "result", - "value_selector": ["llm", "text"] - }] - } + "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], + }, } def _convert_to_answer_node(self) -> dict: @@ -629,11 +571,7 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": { - "title": "ANSWER", - "type": NodeType.ANSWER.value, - "answer": "{{#llm.text#}}" - } + "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str) -> dict: @@ -643,11 +581,7 @@ class WorkflowConverter: :param target: target node id :return: """ - return { - "id": f"{source}-{target}", - "source": source, - "target": target - } + return {"id": f"{source}-{target}", "source": source, "target": target} def _append_node(self, graph: dict, node: dict) -> dict: """ @@ -657,9 +591,9 @@ class WorkflowConverter: :param node: Node to append :return: """ - previous_node = graph['nodes'][-1] - graph['nodes'].append(node) - graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + previous_node = graph["nodes"][-1] + graph["nodes"].append(node) + graph["edges"].append(self._create_edge(previous_node["id"], node["id"])) return graph def _get_new_app_mode(self, app_model: App) -> AppMode: @@ -673,14 +607,20 @@ class WorkflowConverter: else: return AppMode.ADVANCED_CHAT - def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: 
str): """ Get API Based Extension :param tenant_id: tenant id :param api_based_extension_id: api based extension id :return: """ - return db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) + + if not api_based_extension: + raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") + + return api_based_extension diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2defb4cd6a..c593b66f36 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -302,6 +302,7 @@ class WorkflowService: app_model=app_model, account=account, name=args.get('name'), + icon_type=args.get('icon_type'), icon=args.get('icon'), icon_background=args.get('icon_background'), ) diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 063ffca6fd..6e6b16045d 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -21,8 +21,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: Add annotation to index. 
:param job_id: job_id :param content_list: content list - :param tenant_id: tenant id :param app_id: app id + :param tenant_id: tenant id :param user_id: user_id """ diff --git a/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py new file mode 100644 index 0000000000..7b3ff82727 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.siliconflow.rerank.rerank import SiliconflowRerankModel + + +def test_validate_credentials(): + model = SiliconflowRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": "invalid_key" + }, + ) + + model.validate_credentials( + model="BAAI/bge-reranker-v2-m3", + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + ) + + +def test_invoke_model(): + model = SiliconflowRerankModel() + + result = model.invoke( + model='BAAI/bge-reranker-v2-m3', + credentials={ + "api_key": os.environ.get("API_KEY"), + }, + query="Who is Kasumi?", + docs=[ + "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty." 
+ ], + score_threshold=0.8 + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py new file mode 100644 index 0000000000..d886226cf9 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py @@ -0,0 +1,81 @@ +import os +from time import sleep + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel + + +def test_invoke_embedding_v1(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model='embedding-v1', + credentials={ + 'api_key': os.environ.get('WENXIN_API_KEY'), + 'secret_key': os.environ.get('WENXIN_SECRET_KEY') + }, + texts=['hello', '你好', 'xxxxx'], + user="abc-123" + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_en(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model='bge-large-en', + credentials={ + 'api_key': os.environ.get('WENXIN_API_KEY'), + 'secret_key': os.environ.get('WENXIN_SECRET_KEY') + }, + texts=['hello', '你好', 'xxxxx'], + user="abc-123" + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_bge_large_zh(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model='bge-large-zh', + credentials={ + 'api_key': os.environ.get('WENXIN_API_KEY'), + 'secret_key': os.environ.get('WENXIN_SECRET_KEY') + }, + texts=['hello', '你好', 'xxxxx'], + user="abc-123" + ) + + assert isinstance(response, 
TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) + + +def test_invoke_embedding_tao_8k(): + sleep(3) + model = WenxinTextEmbeddingModel() + + response = model.invoke( + model='tao-8k', + credentials={ + 'api_key': os.environ.get('WENXIN_API_KEY'), + 'secret_key': os.environ.get('WENXIN_SECRET_KEY') + }, + texts=['hello', '你好', 'xxxxx'], + user="abc-123" + ) + + assert isinstance(response, TextEmbeddingResult) + assert len(response.embeddings) == 3 + assert isinstance(response.embeddings[0], list) diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py index 851599c7ce..c5a986b747 100644 --- a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py +++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py @@ -21,10 +21,6 @@ class PGVectorTest(AbstractVectorTest): ), ) - def search_by_full_text(self): - hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) - assert len(hits_by_full_text) == 0 - def test_pgvector(setup_mock_redis): PGVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 4686ce0675..1b27af5af7 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,7 +11,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db @@ -66,10 +66,10 @@ def 
test_execute_llm(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather today?', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['abc', 'output'], 'sunny') @@ -181,10 +181,10 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather today?', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather today?', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['abc', 'output'], 'sunny') diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index adf5ffe3ca..e32fa59df3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -13,7 +13,7 @@ from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database 
import db @@ -119,10 +119,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -177,10 +177,10 @@ def test_instructions(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -243,10 +243,10 @@ def test_chat_parameter_extractor(setup_anthropic_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -307,10 +307,10 @@ def test_completion_parameter_extractor(setup_openai_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + 
SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) @@ -420,10 +420,10 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.QUERY: 'what\'s the weather in SF', - SystemVariable.FILES: [], - SystemVariable.CONVERSATION_ID: 'abababa', - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.QUERY: 'what\'s the weather in SF', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) result = node.run(pool) diff --git a/api/tests/unit_tests/core/app/segments/test_segment.py b/api/tests/unit_tests/core/app/segments/test_segment.py index 7e3e69ffbf..50d991316d 100644 --- a/api/tests/unit_tests/core/app/segments/test_segment.py +++ b/api/tests/unit_tests/core/app/segments/test_segment.py @@ -1,13 +1,13 @@ from core.app.segments import SecretVariable, StringSegment, parser from core.helper import encrypter from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey def test_segment_group_to_text(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey('user_id'): 'fake-user-id', }, user_inputs={}, environment_variables=[ @@ -42,7 +42,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): variable_pool = VariablePool( system_variables={ - SystemVariable('user_id'): 'fake-user-id', + SystemVariableKey('user_id'): 'fake-user-id', }, user_inputs={}, environment_variables=[], diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/tests/unit_tests/core/model_runtime/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py new file mode 100644 index 0000000000..68334fde82 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py @@ -0,0 +1,75 @@ +import numpy as np + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( + TextEmbedding, + WenxinTextEmbeddingModel, +) + + +def test_max_chunks(): + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += len(text) + + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + model = 'embedding-v1' + credentials = { + 'api_key': 'xxxx', + 'secret_key': 'yyyy', + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + max_chunks = embedding_model._get_max_chunks(model, credentials) + embedding_model._create_text_embedding = _create_text_embedding + + texts = ['0123456789' for i in range(0, max_chunks * 2)] + result: TextEmbeddingResult = 
embedding_model.invoke(model, credentials, texts, 'test') + assert len(result.embeddings) == max_chunks * 2 + + +def test_context_size(): + def get_num_tokens_by_gpt2(text: str) -> int: + return GPT2Tokenizer.get_num_tokens(text) + + def mock_text(token_size: int) -> str: + _text = "".join(['0' for i in range(token_size)]) + num_tokens = get_num_tokens_by_gpt2(_text) + ratio = int(np.floor(len(_text) / num_tokens)) + m_text = "".join([_text for i in range(ratio)]) + return m_text + + model = 'embedding-v1' + credentials = { + 'api_key': 'xxxx', + 'secret_key': 'yyyy', + } + embedding_model = WenxinTextEmbeddingModel() + context_size = embedding_model._get_context_size(model, credentials) + + class _MockTextEmbedding(TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] + tokens = 0 + for text in texts: + tokens += get_num_tokens_by_gpt2(text) + return embeddings, tokens, tokens + + def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: + return _MockTextEmbedding() + + embedding_model._create_text_embedding = _create_text_embedding + text = mock_text(context_size * 2) + assert get_num_tokens_by_gpt2(text) == context_size * 2 + + texts = [text] + result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') + assert result.usage.tokens == context_size diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 4617b6a42f..44b7c85256 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import 
SystemVariableKey from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import UserFrom from extensions.ext_database import db @@ -29,8 +29,8 @@ def test_execute_answer(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'weather'], 'sunny') pool.add(['llm', 'text'], 'You are a helpful AI.') diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index d21b7785c4..87ebcb34e6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db @@ -119,8 +119,8 @@ def test_execute_if_else_result_true(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 'array_contains'], ['ab', 'def']) pool.add(['start', 'array_not_contains'], ['ac', 'def']) @@ -182,8 +182,8 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool(system_variables={ - SystemVariable.FILES: [], - SystemVariable.USER_ID: 'aaa' + SystemVariableKey.FILES: [], + SystemVariableKey.USER_ID: 'aaa' }, user_inputs={}, environment_variables=[]) pool.add(['start', 
'array_contains'], ['1ab', 'def']) pool.add(['start', 'array_not_contains'], ['ab', 'def']) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py index 0b37d06fc0..5df8c1b763 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py @@ -4,7 +4,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.segments import ArrayStringVariable, StringVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable +from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode @@ -42,7 +42,7 @@ def test_overwrite_string_variable(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -52,7 +52,7 @@ def test_overwrite_string_variable(): input_variable, ) - with mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: node.run(variable_pool) mock_run.assert_called_once() @@ -93,7 +93,7 @@ def test_append_variable_to_array(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -103,7 +103,7 @@ def test_append_variable_to_array(): input_variable, ) - with 
mock.patch('core.workflow.nodes.variable_assigner.update_conversation_variable') as mock_run: + with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run: node.run(variable_pool) mock_run.assert_called_once() @@ -137,7 +137,7 @@ def test_clear_array(): ) variable_pool = VariablePool( - system_variables={SystemVariable.CONVERSATION_ID: 'conversation_id'}, + system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'}, user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index facea34b5b..bea896b83a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -11,7 +11,17 @@ def test_environment_variables(): contexts.tenant_id.set('tenant_id') # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id='tenant_id', + app_id='app_id', + type='workflow', + version='draft', + graph='{}', + features='{}', + created_by='account_id', + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) @@ -35,7 +45,17 @@ def test_update_environment_variables(): contexts.tenant_id.set('tenant_id') # Create a Workflow instance - workflow = Workflow() + workflow = Workflow( + tenant_id='tenant_id', + app_id='app_id', + type='workflow', + version='draft', + graph='{}', + features='{}', + created_by='account_id', + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())}) @@ -70,9 +90,17 @@ def test_to_dict(): contexts.tenant_id.set('tenant_id') # Create a Workflow instance - workflow = Workflow() - workflow.graph = '{}' - 
workflow.features = '{}' + workflow = Workflow( + tenant_id='tenant_id', + app_id='app_id', + type='workflow', + version='draft', + graph='{}', + features='{}', + created_by='account_id', + environment_variables=[], + conversation_variables=[], + ) # Create some EnvironmentVariable instances diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index f589cd2097..a45423bf39 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, VariableEntity, + VariableEntityType, ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode @@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter @pytest.fixture def default_variables(): - return [ + value = [ VariableEntity( variable="text_input", label="text-input", - type=VariableEntity.Type.TEXT_INPUT + type=VariableEntityType.TEXT_INPUT, ), VariableEntity( variable="paragraph", label="paragraph", - type=VariableEntity.Type.PARAGRAPH + type=VariableEntityType.PARAGRAPH, ), VariableEntity( variable="select", label="select", - type=VariableEntity.Type.SELECT - ) + type=VariableEntityType.SELECT, + ), ] + return value def test__convert_to_start_node(default_variables): diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py index 2237319904..1235e559c9 100644 --- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py +++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py @@ -2,7 +2,7 @@ from textwrap import dedent import pytest -from core.helper.position_helper import get_position_map +from core.helper.position_helper import get_position_map, is_filtered, 
pin_position_map, sort_by_position_map @pytest.fixture @@ -14,7 +14,7 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str: - second # - commented - third - + - 9999999999999 - forth """)) @@ -28,9 +28,9 @@ def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str: """\ # - commented1 # - commented2 - - - - - + - + - + """)) return str(tmp_path) @@ -53,3 +53,79 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya folder_path=prepare_empty_commented_positions_yaml, file_name='example_positions_all_commented.yaml') assert position_map == {} + + +def test_excluded_position_data(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['forth', 'first'] + include_set = set() + exclude_set = {'9999999999999'} + + position_map = pin_position_map( + original_position_map=position_map, + pin_list=pin_list + ) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = [item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2'] + + +def test_included_position_data(prepare_example_positions_yaml): + position_map = get_position_map( + folder_path=prepare_example_positions_yaml, + file_name='example_positions.yaml' + ) + pin_list = ['forth', 'first'] + include_set = {'forth', 'first'} + exclude_set = {} + + position_map = pin_position_map( + original_position_map=position_map, + pin_list=pin_list + ) + + data = [ + "forth", + "first", + "second", + "third", + "9999999999999", + "extra1", + "extra2", + ] + + # filter out the data + data = 
[item for item in data if not is_filtered(include_set, exclude_set, item, lambda x: x)] + + # sort data by position map + sorted_data = sort_by_position_map( + position_map=position_map, + data=data, + name_func=lambda x: x, + ) + + # assert the result in the correct order + assert sorted_data == ['forth', 'first'] diff --git a/dev/reformat b/dev/reformat index f50ccb04c4..ad83e897d9 100755 --- a/dev/reformat +++ b/dev/reformat @@ -11,5 +11,8 @@ fi # run ruff linter ruff check --fix ./api +# run ruff formatter +ruff format ./api + # run dotenv-linter linter dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index aed2586053..edefd129d5 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.7.0 + image: langgenius/dify-api:0.7.1 restart: always environment: # Startup mode, 'api' starts the API server. @@ -229,7 +229,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.0 + image: langgenius/dify-api:0.7.1 restart: always environment: CONSOLE_WEB_URL: '' @@ -400,7 +400,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.7.0 + image: langgenius/dify-web:0.7.1 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 6fee8b4b3c..f530f27a47 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -303,7 +303,7 @@ TENCENT_COS_SCHEME=your-scheme # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`. 
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`. VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -396,6 +396,12 @@ TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_REPLICAS=2 +# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` +ELASTICSEARCH_HOST=0.0.0.0 +ELASTICSEARCH_PORT=9200 +ELASTICSEARCH_USERNAME=elastic +ELASTICSEARCH_PASSWORD=elastic + # ------------------------------ # Knowledge Configuration # ------------------------------ @@ -508,6 +514,7 @@ RESET_PASSWORD_TOKEN_EXPIRY_HOURS=24 CODE_EXECUTION_ENDPOINT=http://sandbox:8194 CODE_MAX_NUMBER=9223372036854775807 CODE_MIN_NUMBER=-9223372036854775808 +CODE_MAX_DEPTH=5 CODE_MAX_STRING_LENGTH=80000 TEMPLATE_TRANSFORM_MAX_LENGTH=80000 CODE_MAX_STRING_ARRAY_LENGTH=30 @@ -695,3 +702,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} # ------------------------------ EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 + +# ---------------------------------------------------------------------------- +# ModelProvider & Tool Position Configuration +# Used to specify the model providers and tools that can be used in the app. +# ---------------------------------------------------------------------------- + +# Pin, include, and exclude tools +# Use comma-separated values with no spaces between items. +# Example: POSITION_TOOL_PINS=bing,google +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +# Pin, include, and exclude model providers +# Use comma-separated values with no spaces between items. 
+# Example: POSITION_PROVIDER_PINS=openai,openllm +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= \ No newline at end of file diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 3aa84d009e..9a4c40448b 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -34,7 +34,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.1 + image: langgenius/dify-sandbox:0.2.6 restart: always environment: # The DifySandbox configurations diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index f3151bbc2a..b9fa5e6786 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,5 +1,6 @@ x-shared-env: &shared-api-worker-env LOG_LEVEL: ${LOG_LEVEL:-INFO} + LOG_FILE: ${LOG_FILE:-} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} @@ -35,11 +36,6 @@ x-shared-env: &shared-api-worker-env SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} - POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100} - POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB} - POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} - POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} - POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} REDIS_HOST: ${REDIS_HOST:-redis} REDIS_PORT: ${REDIS_PORT:-6379} REDIS_USERNAME: ${REDIS_USERNAME:-} @@ -125,10 +121,11 @@ x-shared-env: &shared-api-worker-env CHROMA_DATABASE: ${CHROMA_DATABASE:-default_database} CHROMA_AUTH_PROVIDER: ${CHROMA_AUTH_PROVIDER:-chromadb.auth.token_authn.TokenAuthClientProvider} CHROMA_AUTH_CREDENTIALS: ${CHROMA_AUTH_CREDENTIALS:-} - ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-127.0.0.1} + ELASTICSEARCH_HOST: ${ELASTICSEARCH_HOST:-0.0.0.0} ELASTICSEARCH_PORT: 
${ELASTICSEARCH_PORT:-9200} ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-} @@ -180,6 +177,7 @@ x-shared-env: &shared-api-worker-env CODE_EXECUTION_API_KEY: ${SANDBOX_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} + CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5} CODE_MAX_STRING_LENGTH: ${CODE_MAX_STRING_LENGTH:-80000} TEMPLATE_TRANSFORM_MAX_LENGTH: ${TEMPLATE_TRANSFORM_MAX_LENGTH:-80000} CODE_MAX_STRING_ARRAY_LENGTH: ${CODE_MAX_STRING_ARRAY_LENGTH:-30} @@ -191,7 +189,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.7.0 + image: langgenius/dify-api:0.7.1 restart: always environment: # Use the shared environment variables. @@ -211,7 +209,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.7.0 + image: langgenius/dify-api:0.7.1 restart: always environment: # Use the shared environment variables. @@ -230,7 +228,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:0.7.0 + image: langgenius/dify-web:0.7.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -275,7 +273,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.1 + image: langgenius/dify-sandbox:0.2.6 restart: always environment: # The DifySandbox configurations @@ -599,27 +597,61 @@ services: ports: - "${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}" + # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html + # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites elasticsearch: image: docker.elastic.co/elasticsearch/elasticsearch:8.14.3 container_name: elasticsearch profiles: - elasticsearch restart: always + volumes: + - dify_es01_data:/usr/share/elasticsearch/data environment: - - "ELASTIC_PASSWORD=${ELASTICSEARCH_USERNAME:-elastic}" - - "cluster.name=dify-es-cluster" - - "node.name=dify-es0" - - "discovery.type=single-node" - - "xpack.security.http.ssl.enabled=false" - - "xpack.license.self_generated.type=trial" + - ELASTIC_PASSWORD=${ELASTICSEARCH_PASSWORD:-elastic} + - cluster.name=dify-es-cluster + - node.name=dify-es0 + - discovery.type=single-node + - xpack.license.self_generated.type=trial + - xpack.security.enabled=true + - xpack.security.enrollment.enabled=false + - xpack.security.http.ssl.enabled=false ports: - - "${ELASTICSEARCH_PORT:-9200}:${ELASTICSEARCH_PORT:-9200}" + - ${ELASTICSEARCH_PORT:-9200}:9200 healthcheck: test: ["CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty"] interval: 30s timeout: 10s retries: 50 + # https://www.elastic.co/guide/en/kibana/current/docker.html + # https://www.elastic.co/guide/en/kibana/current/settings.html + kibana: + image: docker.elastic.co/kibana/kibana:8.14.3 + container_name: kibana + profiles: + - elasticsearch + depends_on: + - elasticsearch + restart: always + environment: + - 
XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY=d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa + - NO_PROXY=localhost,127.0.0.1,elasticsearch,kibana + - XPACK_SECURITY_ENABLED=true + - XPACK_SECURITY_ENROLLMENT_ENABLED=false + - XPACK_SECURITY_HTTP_SSL_ENABLED=false + - XPACK_FLEET_ISAIRGAPPED=true + - I18N_LOCALE=zh-CN + - SERVER_PORT=5601 + - ELASTICSEARCH_HOSTS="http://elasticsearch:9200" + ports: + - ${KIBANA_PORT:-5601}:5601 + healthcheck: + test: [ "CMD-SHELL", "curl -s http://localhost:5601 >/dev/null || exit 1" ] + interval: 30s + timeout: 10s + retries: 3 + # unstructured . # (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.) unstructured: @@ -643,3 +675,4 @@ networks: volumes: oradata: + dify_es01_data: diff --git a/sdks/nodejs-client/index.d.ts b/sdks/nodejs-client/index.d.ts index 7fdd943f63..a2e6e50aef 100644 --- a/sdks/nodejs-client/index.d.ts +++ b/sdks/nodejs-client/index.d.ts @@ -33,6 +33,10 @@ export declare class DifyClient { getApplicationParameters(user: User): Promise; fileUpload(data: FormData): Promise; + + textToAudio(text: string ,user: string, streaming?: boolean): Promise; + + getMeta(user: User): Promise; } export declare class CompletionClient extends DifyClient { @@ -54,6 +58,18 @@ export declare class ChatClient extends DifyClient { files?: File[] | null ): Promise; + getSuggested(message_id: string, user: User): Promise; + + stopMessage(task_id: string, user: User) : Promise; + + + getConversations( + user: User, + first_id?: string | null, + limit?: number | null, + pinned?: boolean | null + ): Promise; + getConversationMessages( user: User, conversation_id?: string, @@ -61,9 +77,15 @@ export declare class ChatClient extends DifyClient { limit?: number | null ): Promise; - getConversations(user: User, first_id?: string | null, limit?: number | null, pinned?: boolean | null): Promise; - - renameConversation(conversation_id: string, name: string, user: User): Promise; + renameConversation(conversation_id: string, name: 
string, user: User,auto_generate:boolean): Promise; deleteConversation(conversation_id: string, user: User): Promise; + + audioToText(data: FormData): Promise; +} + +export declare class WorkflowClient extends DifyClient { + run(inputs: any, user: User, stream?: boolean,): Promise; + + stop(task_id: string, user: User): Promise; } \ No newline at end of file diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js index 93491fa16b..858241ce5a 100644 --- a/sdks/nodejs-client/index.js +++ b/sdks/nodejs-client/index.js @@ -2,30 +2,55 @@ import axios from "axios"; export const BASE_URL = "https://api.dify.ai/v1"; export const routes = { - application: { - method: "GET", - url: () => `/parameters`, - }, + // app's feedback: { method: "POST", url: (message_id) => `/messages/${message_id}/feedbacks`, }, + application: { + method: "GET", + url: () => `/parameters`, + }, + fileUpload: { + method: "POST", + url: () => `/files/upload`, + }, + textToAudio: { + method: "POST", + url: () => `/text-to-audio`, + }, + getMeta: { + method: "GET", + url: () => `/meta`, + }, + + // completion's createCompletionMessage: { method: "POST", url: () => `/completion-messages`, }, + + // chat's createChatMessage: { method: "POST", url: () => `/chat-messages`, }, - getConversationMessages: { + getSuggested:{ method: "GET", - url: () => `/messages`, + url: (message_id) => `/messages/${message_id}/suggested`, + }, + stopChatMessage: { + method: "POST", + url: (task_id) => `/chat-messages/${task_id}/stop`, }, getConversations: { method: "GET", url: () => `/conversations`, }, + getConversationMessages: { + method: "GET", + url: () => `/messages`, + }, renameConversation: { method: "POST", url: (conversation_id) => `/conversations/${conversation_id}/name`, @@ -34,14 +59,21 @@ export const routes = { method: "DELETE", url: (conversation_id) => `/conversations/${conversation_id}`, }, - fileUpload: { + audioToText: { method: "POST", - url: () => `/files/upload`, + url: () => 
`/audio-to-text`, }, + + // workflow‘s runWorkflow: { method: "POST", url: () => `/workflows/run`, }, + stopWorkflow: { + method: "POST", + url: (task_id) => `/workflows/${task_id}/stop`, + } + }; export class DifyClient { @@ -129,6 +161,31 @@ export class DifyClient { } ); } + + textToAudio(text, user, streaming = false) { + const data = { + text, + user, + streaming + }; + return this.sendRequest( + routes.textToAudio.method, + routes.textToAudio.url(), + data, + null, + streaming + ); + } + + getMeta(user) { + const params = { user }; + return this.sendRequest( + routes.meta.method, + routes.meta.url(), + null, + params + ); + } } export class CompletionClient extends DifyClient { @@ -191,6 +248,34 @@ export class ChatClient extends DifyClient { ); } + getSuggested(message_id, user) { + const data = { user }; + return this.sendRequest( + routes.getSuggested.method, + routes.getSuggested.url(message_id), + data + ); + } + + stopMessage(task_id, user) { + const data = { user }; + return this.sendRequest( + routes.stopChatMessage.method, + routes.stopChatMessage.url(task_id), + data + ); + } + + getConversations(user, first_id = null, limit = null, pinned = null) { + const params = { user, first_id: first_id, limit, pinned }; + return this.sendRequest( + routes.getConversations.method, + routes.getConversations.url(), + null, + params + ); + } + getConversationMessages( user, conversation_id = "", @@ -213,16 +298,6 @@ export class ChatClient extends DifyClient { ); } - getConversations(user, first_id = null, limit = null, pinned = null) { - const params = { user, first_id: first_id, limit, pinned }; - return this.sendRequest( - routes.getConversations.method, - routes.getConversations.url(), - null, - params - ); - } - renameConversation(conversation_id, name, user, auto_generate) { const data = { name, user, auto_generate }; return this.sendRequest( @@ -240,4 +315,46 @@ export class ChatClient extends DifyClient { data ); } + + + audioToText(data) { + return 
this.sendRequest( + routes.audioToText.method, + routes.audioToText.url(), + data, + null, + false, + { + "Content-Type": 'multipart/form-data' + } + ); + } + +} + +export class WorkflowClient extends DifyClient { + run(inputs,user,stream) { + const data = { + inputs, + response_mode: stream ? "streaming" : "blocking", + user + }; + + return this.sendRequest( + routes.runWorkflow.method, + routes.runWorkflow.url(), + data, + null, + stream + ); + } + + stop(task_id, user) { + const data = { user }; + return this.sendRequest( + routes.stopWorkflow.method, + routes.stopWorkflow.url(task_id), + data + ); + } } \ No newline at end of file diff --git a/sdks/php-client/dify-client.php b/sdks/php-client/dify-client.php index c5ddcc3a87..ccd61f091a 100644 --- a/sdks/php-client/dify-client.php +++ b/sdks/php-client/dify-client.php @@ -9,9 +9,9 @@ class DifyClient { protected $base_url; protected $client; - public function __construct($api_key) { + public function __construct($api_key, $base_url = null) { $this->api_key = $api_key; - $this->base_url = "https://api.dify.ai/v1/"; + $this->base_url = $base_url ?? 
"https://api.dify.ai/v1/"; $this->client = new Client([ 'base_uri' => $this->base_url, 'headers' => [ @@ -80,6 +80,25 @@ class DifyClient { return $multipart; } + + + public function text_to_audio($text, $user, $streaming = false) { + $data = [ + 'text' => $text, + 'user' => $user, + 'streaming' => $streaming + ]; + + return $this->send_request('POST', 'text-to-audio', $data); + } + + public function get_meta($user) { + $params = [ + 'user' => $user + ]; + + return $this->send_request('GET', 'meta',null, $params); + } } class CompletionClient extends DifyClient { @@ -109,6 +128,28 @@ class ChatClient extends DifyClient { return $this->send_request('POST', 'chat-messages', $data, null, $response_mode === 'streaming'); } + + public function get_suggestions($message_id, $user) { + $params = [ + 'user' => $user + ] + return $this->send_request('GET', "messages/{$message_id}/suggested", null, $params); + } + + public function stop_message($task_id, $user) { + $data = ['user' => $user]; + return $this->send_request('POST', "chat-messages/{$task_id}/stop", $data); + } + + public function get_conversations($user, $first_id = null, $limit = null, $pinned = null) { + $params = [ + 'user' => $user, + 'first_id' => $first_id, + 'limit' => $limit, + 'pinned'=> $pinned, + ]; + return $this->send_request('GET', 'conversations', null, $params); + } public function get_conversation_messages($user, $conversation_id = null, $first_id = null, $limit = null) { $params = ['user' => $user]; @@ -125,22 +166,51 @@ class ChatClient extends DifyClient { return $this->send_request('GET', 'messages', null, $params); } - - public function get_conversations($user, $first_id = null, $limit = null, $pinned = null) { - $params = [ - 'user' => $user, - 'first_id' => $first_id, - 'limit' => $limit, - 'pinned'=> $pinned, - ]; - return $this->send_request('GET', 'conversations', null, $params); - } - - public function rename_conversation($conversation_id, $name, $user) { + + public function 
rename_conversation($conversation_id, $name,$auto_generate, $user) { $data = [ 'name' => $name, 'user' => $user, + 'auto_generate' => $auto_generate ]; return $this->send_request('PATCH', "conversations/{$conversation_id}", $data); } + + public function delete_conversation($conversation_id, $user) { + $data = [ + 'user' => $user, + ]; + return $this->send_request('DELETE', "conversations/{$conversation_id}", $data); + } + + public function audio_to_text($audio_file, $user) { + $data = [ + 'user' => $user, + ]; + $options = [ + 'multipart' => $this->prepareMultipart($data, $files) + ]; + return $this->file_client->request('POST', 'audio-to-text', $options); + + } + } + +class WorkflowClient extends DifyClient{ + public function run($inputs, $response_mode, $user) { + $data = [ + 'inputs' => $inputs, + 'response_mode' => $response_mode, + 'user' => $user, + ]; + return $this->send_request('POST', 'workflows/run', $data); + } + + public function stop($task_id, $user) { + $data = [ + 'user' => $user, + ]; + return $this->send_request('POST', "workflows/tasks/{$task_id}/stop",$data); + } + +} \ No newline at end of file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index 22d4e6189a..a8da1d7cae 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -16,6 +16,7 @@ class DifyClient: response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream) return response + def _send_request_with_files(self, method, endpoint, data, files): headers = { @@ -26,24 +27,36 @@ class DifyClient: response = requests.request(method, url, data=data, headers=headers, files=files) return response - + def message_feedback(self, message_id, rating, user): data = { "rating": rating, "user": user } return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) - + def get_application_parameters(self, user): params = {"user": user} return 
self._send_request("GET", "/parameters", params=params) - + def file_upload(self, user, files): data = { "user": user } return self._send_request_with_files("POST", "/files/upload", data=data, files=files) + def text_to_audio(self, text:str, user:str, streaming:bool=False): + data = { + "text": text, + "user": user, + "streaming": streaming + } + return self._send_request("POST", "/text-to-audio", data=data) + + def get_meta(self,user): + params = { "user": user} + return self._send_request("GET", f"/meta", params=params) + class CompletionClient(DifyClient): def create_completion_message(self, inputs, response_mode, user, files=None): @@ -71,7 +84,19 @@ class ChatClient(DifyClient): return self._send_request("POST", "/chat-messages", data, stream=True if response_mode == "streaming" else False) + + def get_suggested(self, message_id, user:str): + params = {"user": user} + return self._send_request("GET", f"/messages/{message_id}/suggested", params=params) + + def stop_message(self, task_id, user): + data = {"user": user} + return self._send_request("POST", f"/chat-messages/{task_id}/stop", data) + def get_conversations(self, user, last_id=None, limit=None, pinned=None): + params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} + return self._send_request("GET", "/conversations", params=params) + def get_conversation_messages(self, user, conversation_id=None, first_id=None, limit=None): params = {"user": user} @@ -83,11 +108,26 @@ class ChatClient(DifyClient): params["limit"] = limit return self._send_request("GET", "/messages", params=params) - - def get_conversations(self, user, last_id=None, limit=None, pinned=None): - params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned} - return self._send_request("GET", "/conversations", params=params) - - def rename_conversation(self, conversation_id, name, user): - data = {"name": name, "user": user} + + def rename_conversation(self, conversation_id, name,auto_generate:bool, 
user:str): + data = {"name": name, "auto_generate": auto_generate,"user": user} return self._send_request("POST", f"/conversations/{conversation_id}/name", data) + + def delete_conversation(self, conversation_id, user): + data = {"user": user} + return self._send_request("DELETE", f"/conversations/{conversation_id}", data) + + def audio_to_text(self, audio_file, user): + data = {"user": user} + files = {"audio_file": audio_file} + return self._send_request_with_files("POST", "/audio-to-text", data, files) + + +class WorkflowClient(DifyClient): + def run(self, inputs:dict, response_mode:str="streaming", user:str="abc-123"): + data = {"inputs": inputs, "response_mode": response_mode, "user": user} + return self._send_request("POST", "/workflows/run", data) + + def stop(self, task_id, user): + data = {"user": user} + return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data) diff --git a/web/.env.example b/web/.env.example index 439092c20e..3045cef2f9 100644 --- a/web/.env.example +++ b/web/.env.example @@ -15,4 +15,7 @@ NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api NEXT_PUBLIC_SENTRY_DSN= # Disable Next.js Telemetry (https://nextjs.org/telemetry) -NEXT_TELEMETRY_DISABLED=1 \ No newline at end of file +NEXT_TELEMETRY_DISABLED=1 + +# Disable Upload Image as WebApp icon default is false +NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx index 74e9140fe5..ff32a157fc 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -25,7 +25,7 @@ export default function ChartView({ appId }: IChartViewProps) { const appDetail = useAppStore(state => state.appDetail) const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow' const isWorkflow = appDetail?.mode === 
'workflow' - const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }) const onSelect = (item: Item) => { if (item.value === 'all') { @@ -37,7 +37,7 @@ export default function ChartView({ appId }: IChartViewProps) { setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } }) } else { - setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').startOf('day').format(queryDateFormat), end: today.endOf('day').format(queryDateFormat) } }) } } diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 1387099a62..bc7308a711 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -75,6 +75,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, @@ -83,6 +84,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { await updateAppInfo({ appID: app.id, name, + icon_type, icon, icon_background, description, @@ -101,11 +103,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } }, [app.id, mutateApps, notify, onRefresh, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { try { const newApp = await copyApp({ appID: app.id, name, + icon_type, icon, 
icon_background, mode: app.mode, @@ -258,8 +261,10 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
{app.mode === 'advanced-chat' && ( @@ -360,9 +365,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {showEditModal && ( { {showDuplicateModal && ( setShowDuplicateModal(false)} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index 11893ec9de..7cd083d11b 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -60,7 +60,7 @@ const LikedItem = ({ return (
- + {type === 'app' && ( {detail.mode === 'advanced-chat' && ( diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 698846cae5..c5bc3bb210 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -59,6 +59,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, @@ -69,6 +70,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const app = await updateAppInfo({ appID: appDetail.id, name, + icon_type, icon, icon_background, description, @@ -86,13 +88,14 @@ const AppInfo = ({ expand }: IAppInfoProps) => { } }, [appDetail, mutateApps, notify, setAppDetail, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { if (!appDetail) return try { const newApp = await copyApp({ appID: appDetail.id, name, + icon_type, icon, icon_background, mode: appDetail.mode, @@ -194,7 +197,13 @@ const AppInfo = ({ expand }: IAppInfoProps) => { >
- + { {/* header */}
- + {appDetail.mode === 'advanced-chat' && ( @@ -402,9 +417,11 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {showEditModal && ( { {showDuplicateModal && ( setShowDuplicateModal(false)} diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index ccc2c1a6d5..6204e9569a 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -8,6 +8,8 @@ import { } from '@remixicon/react' import { useRouter } from 'next/navigation' import { useContext, useContextSelector } from 'use-context-selector' +import AppIconPicker from '../../base/app-icon-picker' +import type { AppIconSelection } from '../../base/app-icon-picker' import s from './style.module.css' import cn from '@/utils/classnames' import AppsContext, { useAppContext } from '@/context/app-context' @@ -20,7 +22,6 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' import AppIcon from '@/app/components/base/app-icon' -import EmojiPicker from '@/app/components/base/emoji-picker' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' @@ -42,8 +43,8 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { const [appMode, setAppMode] = useState('chat') const [showChatBotType, setShowChatBotType] = useState(true) - const [emoji, setEmoji] = useState({ icon: '🤖', icon_background: '#FFEAD5' }) - const [showEmojiPicker, setShowEmojiPicker] = useState(false) + const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [name, setName] = useState('') const [description, 
setDescription] = useState('') @@ -68,8 +69,9 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { const app = await createApp({ name, description, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, mode: appMode, }) notify({ type: 'success', message: t('app.newApp.appCreated') }) @@ -83,7 +85,7 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) } isCreatingRef.current = false - }, [name, notify, t, appMode, emoji.icon, emoji.icon_background, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor]) + }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor]) return ( {
{t('app.newApp.captionName')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + /> setName(e.target.value)} @@ -279,14 +288,13 @@ const CreateAppModal = ({ show, onSuccess, onClose }: CreateAppDialogProps) => { className='grow h-10' />
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: '🤖', icon_background: '#FFEAD5' }) - setShowEmojiPicker(false) + setShowAppIconPicker(false) }} />}
diff --git a/web/app/components/app/duplicate-modal/index.tsx b/web/app/components/app/duplicate-modal/index.tsx index 10a79d84e2..bcad1c24f2 100644 --- a/web/app/components/app/duplicate-modal/index.tsx +++ b/web/app/components/app/duplicate-modal/index.tsx @@ -1,6 +1,7 @@ 'use client' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' +import AppIconPicker from '../../base/app-icon-picker' import s from './style.module.css' import cn from '@/utils/classnames' import Modal from '@/app/components/base/modal' @@ -8,26 +9,32 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import AppIcon from '@/app/components/base/app-icon' -import EmojiPicker from '@/app/components/base/emoji-picker' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' +import type { AppIconType } from '@/types/app' + export type DuplicateAppModalProps = { appName: string + icon_type: AppIconType | null icon: string - icon_background: string + icon_background?: string | null + icon_url?: string | null show: boolean onConfirm: (info: { name: string + icon_type: AppIconType icon: string - icon_background: string + icon_background?: string | null }) => Promise onHide: () => void } const DuplicateAppModal = ({ appName, + icon_type, icon, icon_background, + icon_url, show = false, onConfirm, onHide, @@ -36,8 +43,12 @@ const DuplicateAppModal = ({ const [name, setName] = React.useState(appName) - const [showEmojiPicker, setShowEmojiPicker] = useState(false) - const [emoji, setEmoji] = useState({ icon, icon_background }) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + icon_type === 'image' + ? 
{ type: 'image' as const, url: icon_url, fileId: icon } + : { type: 'emoji' as const, icon, background: icon_background }, + ) const { plan, enableBilling } = useProviderContext() const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) @@ -49,7 +60,9 @@ const DuplicateAppModal = ({ } onConfirm({ name, - ...emoji, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, }) onHide() } @@ -66,7 +79,15 @@ const DuplicateAppModal = ({
{t('explore.appCustomize.subTitle')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + className='cursor-pointer' + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} + /> setName(e.target.value)} @@ -80,14 +101,16 @@ const DuplicateAppModal = ({
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) + setShowAppIconPicker(false) }} />} diff --git a/web/app/components/app/log/filter.tsx b/web/app/components/app/log/filter.tsx index 80a58bb5a2..0552b44d16 100644 --- a/web/app/components/app/log/filter.tsx +++ b/web/app/components/app/log/filter.tsx @@ -10,6 +10,7 @@ import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import type { QueryParam } from './index' import { SimpleSelect } from '@/app/components/base/select' +import Sort from '@/app/components/base/sort' import { fetchAnnotationsCount } from '@/service/log' dayjs.extend(quarterOfYear) @@ -28,18 +29,19 @@ export const TIME_PERIOD_LIST = [ ] type IFilterProps = { + isChatMode?: boolean appId: string queryParams: QueryParam setQueryParams: (v: QueryParam) => void } -const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilterProps) => { +const Filter: FC = ({ isChatMode, appId, queryParams, setQueryParams }: IFilterProps) => { const { data } = useSWR({ url: `/apps/${appId}/annotations/count` }, fetchAnnotationsCount) const { t } = useTranslation() if (!data) return null return ( -
+
({ value: item.value, name: t(`appLog.filter.period.${item.name}`) }))} className='mt-0 !w-40' @@ -68,7 +70,7 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte { @@ -76,6 +78,22 @@ const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilte }} />
+ {isChatMode && ( + <> +
+ { + setQueryParams({ ...queryParams, sort_by: value as string }) + }} + /> + + )}
) } diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index 828004d6f3..dd6ebd08f0 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -24,6 +24,7 @@ export type QueryParam = { period?: number | string annotation_status?: string keyword?: string + sort_by?: string } const ThreeDotsIcon = ({ className }: SVGProps) => { @@ -52,18 +53,26 @@ const EmptyElement: FC<{ appUrl: string }> = ({ appUrl }) => { const Logs: FC = ({ appDetail }) => { const { t } = useTranslation() - const [queryParams, setQueryParams] = useState({ period: 7, annotation_status: 'all' }) + const [queryParams, setQueryParams] = useState({ + period: 7, + annotation_status: 'all', + sort_by: '-created_at', + }) const [currPage, setCurrPage] = React.useState(0) + // Get the app type first + const isChatMode = appDetail.mode !== 'completion' + const query = { page: currPage + 1, limit: APP_PAGE_LIMIT, ...(queryParams.period !== 'all' ? { start: dayjs().subtract(queryParams.period as number, 'day').startOf('day').format('YYYY-MM-DD HH:mm'), - end: dayjs().format('YYYY-MM-DD HH:mm'), + end: dayjs().endOf('day').format('YYYY-MM-DD HH:mm'), } : {}), + ...(isChatMode ? { sort_by: queryParams.sort_by } : {}), ...omit(queryParams, ['period']), } @@ -73,9 +82,6 @@ const Logs: FC = ({ appDetail }) => { return appType } - // Get the app type first - const isChatMode = appDetail.mode !== 'completion' - // When the details are obtained, proceed to the next request const { data: chatConversations, mutate: mutateChatList } = useSWR(() => isChatMode ? { @@ -97,7 +103,7 @@ const Logs: FC = ({ appDetail }) => {

{t('appLog.description')}

- + {total === undefined ? : total > 0 diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 646ae80116..40d171761b 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -671,12 +671,13 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) - {t('appLog.table.header.time')} - {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')} + {t('appLog.table.header.endUser')} {isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')} {t('appLog.table.header.userRate')} {t('appLog.table.header.adminRate')} + {t('appLog.table.header.updatedTime')} + {t('appLog.table.header.time')} @@ -692,11 +693,10 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) setCurrentConversation(log) }}> {!log.read_at && } - {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} - {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} + {renderTdValue(endUser || defaultValue, !endUser)} {renderTdValue(rightValue === 0 ? 
0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} @@ -718,6 +718,8 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) } + {formatTime(log.updated_at, t('appLog.dateTimeFormat') as string)} + {formatTime(log.created_at, t('appLog.dateTimeFormat') as string)} })} diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 67441ec0ff..b9072e3983 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -12,11 +12,11 @@ import Textarea from '@/app/components/base/textarea' import AppIcon from '@/app/components/base/app-icon' import { SimpleSelect } from '@/app/components/base/select' import type { AppDetailResponse } from '@/models/app' -import type { Language } from '@/types/app' -import EmojiPicker from '@/app/components/base/emoji-picker' +import type { AppIconType, Language } from '@/types/app' import { useToastContext } from '@/app/components/base/toast' - import { languages } from '@/i18n/language' +import type { AppIconSelection } from '@/app/components/base/app-icon-picker' +import AppIconPicker from '@/app/components/base/app-icon-picker' export type ISettingsModalProps = { isChat: boolean @@ -37,8 +37,9 @@ export type ConfigParams = { copyright: string privacy_policy: string custom_disclaimer: string + icon_type: AppIconType icon: string - icon_background: string + icon_background?: string show_workflow_steps: boolean } @@ -53,9 +54,12 @@ const SettingsModal: FC = ({ }) => { const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) - const { icon, icon_background } = appInfo const { title, + icon_type, + icon, + icon_background, + icon_url, description, chat_color_theme, chat_color_theme_inverted, @@ -78,9 +82,13 @@ const SettingsModal: FC = ({ const [language, setLanguage] = useState(default_language) const 
[saveLoading, setSaveLoading] = useState(false) const { t } = useTranslation() - // Emoji Picker - const [showEmojiPicker, setShowEmojiPicker] = useState(false) - const [emoji, setEmoji] = useState({ icon, icon_background }) + + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }, + ) useEffect(() => { setInputInfo({ @@ -94,7 +102,9 @@ const SettingsModal: FC = ({ show_workflow_steps, }) setLanguage(default_language) - setEmoji({ icon, icon_background }) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) }, [appInfo]) const onHide = () => { @@ -137,8 +147,9 @@ const SettingsModal: FC = ({ copyright: inputInfo.copyright, privacy_policy: inputInfo.privacyPolicy, custom_disclaimer: inputInfo.customDisclaimer, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, show_workflow_steps: inputInfo.show_workflow_steps, } await onSave?.(params) @@ -173,10 +184,12 @@ const SettingsModal: FC = ({
{t(`${prefixSettings}.webName`)}
{ setShowEmojiPicker(true) }} + onClick={() => { setShowAppIconPicker(true) }} className='cursor-pointer !mr-3 self-center' - icon={emoji.icon} - background={emoji.icon_background} + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} /> = ({
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: appInfo.site.icon, icon_background: appInfo.site.icon_background }) - setShowEmojiPicker(false) + setAppIcon(icon_type === 'image' + ? { type: 'image', url: icon_url!, fileId: icon } + : { type: 'emoji', icon, background: icon_background! }) + setShowAppIconPicker(false) }} />} diff --git a/web/app/components/app/switch-app-modal/index.tsx b/web/app/components/app/switch-app-modal/index.tsx index 366cfe7a4b..5b45095251 100644 --- a/web/app/components/app/switch-app-modal/index.tsx +++ b/web/app/components/app/switch-app-modal/index.tsx @@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' +import AppIconPicker from '../../base/app-icon-picker' import s from './style.module.css' import cn from '@/utils/classnames' import Checkbox from '@/app/components/base/checkbox' @@ -17,7 +18,6 @@ import { deleteApp, switchApp } from '@/service/apps' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' -import EmojiPicker from '@/app/components/base/emoji-picker' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { getRedirection } from '@/utils/app-redirection' import type { App } from '@/types/app' @@ -43,8 +43,13 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo const { plan, enableBilling } = useProviderContext() const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) - const [emoji, setEmoji] = useState({ icon: appDetail.icon, icon_background: appDetail.icon_background }) - const [showEmojiPicker, setShowEmojiPicker] = 
useState(false) + const [showAppIconPicker, setShowAppIconPicker] = useState(false) + const [appIcon, setAppIcon] = useState( + appDetail.icon_type === 'image' + ? { type: 'image' as const, url: appDetail.icon_url, fileId: appDetail.icon } + : { type: 'emoji' as const, icon: appDetail.icon, background: appDetail.icon_background }, + ) + const [name, setName] = useState(`${appDetail.name}(copy)`) const [removeOriginal, setRemoveOriginal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) @@ -54,8 +59,9 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo const { new_app_id: newAppID } = await switchApp({ appID: appDetail.id, name, - icon: emoji.icon, - icon_background: emoji.icon_background, + icon_type: appIcon.type, + icon: appIcon.type === 'emoji' ? appIcon.icon : appIcon.fileId, + icon_background: appIcon.type === 'emoji' ? appIcon.background : undefined, }) if (onSuccess) onSuccess() @@ -108,7 +114,15 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo
{t('app.switchLabel')}
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.icon} background={emoji.icon_background} /> + { setShowAppIconPicker(true) }} + className='cursor-pointer' + iconType={appIcon.type} + icon={appIcon.type === 'image' ? appIcon.fileId : appIcon.icon} + background={appIcon.type === 'image' ? undefined : appIcon.background} + imageUrl={appIcon.type === 'image' ? appIcon.url : undefined} + /> setName(e.target.value)} @@ -116,14 +130,16 @@ const SwitchAppModal = ({ show, appDetail, inAppDetail = false, onSuccess, onClo className='grow h-10' />
- {showEmojiPicker && { - setEmoji({ icon, icon_background }) - setShowEmojiPicker(false) + {showAppIconPicker && { + setAppIcon(payload) + setShowAppIconPicker(false) }} onClose={() => { - setEmoji({ icon: appDetail.icon, icon_background: appDetail.icon_background }) - setShowEmojiPicker(false) + setAppIcon(appDetail.icon_type === 'image' + ? { type: 'image' as const, url: appDetail.icon_url, fileId: appDetail.icon } + : { type: 'emoji' as const, icon: appDetail.icon, background: appDetail.icon_background }) + setShowAppIconPicker(false) }} />}
diff --git a/web/app/components/base/app-icon-picker/Uploader.tsx b/web/app/components/base/app-icon-picker/Uploader.tsx new file mode 100644 index 0000000000..4ddaa40447 --- /dev/null +++ b/web/app/components/base/app-icon-picker/Uploader.tsx @@ -0,0 +1,97 @@ +'use client' + +import type { ChangeEvent, FC } from 'react' +import { createRef, useEffect, useState } from 'react' +import type { Area } from 'react-easy-crop' +import Cropper from 'react-easy-crop' +import classNames from 'classnames' + +import { ImagePlus } from '../icons/src/vender/line/images' +import { useDraggableUploader } from './hooks' +import { ALLOW_FILE_EXTENSIONS } from '@/types/app' + +type UploaderProps = { + className?: string + onImageCropped?: (tempUrl: string, croppedAreaPixels: Area, fileName: string) => void +} + +const Uploader: FC = ({ + className, + onImageCropped, +}) => { + const [inputImage, setInputImage] = useState<{ file: File; url: string }>() + useEffect(() => { + return () => { + if (inputImage) + URL.revokeObjectURL(inputImage.url) + } + }, [inputImage]) + + const [crop, setCrop] = useState({ x: 0, y: 0 }) + const [zoom, setZoom] = useState(1) + + const onCropComplete = async (_: Area, croppedAreaPixels: Area) => { + if (!inputImage) + return + onImageCropped?.(inputImage.url, croppedAreaPixels, inputImage.file.name) + } + + const handleLocalFileInput = (e: ChangeEvent) => { + const file = e.target.files?.[0] + if (file) + setInputImage({ file, url: URL.createObjectURL(file) }) + } + + const { + isDragActive, + handleDragEnter, + handleDragOver, + handleDragLeave, + handleDrop, + } = useDraggableUploader((file: File) => setInputImage({ file, url: URL.createObjectURL(file) })) + + const inputRef = createRef() + + return ( +
+
+ { + !inputImage + ? <> + +
+ Drop your image here, or  + + ((e.target as HTMLInputElement).value = '')} + accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')} + onChange={handleLocalFileInput} + /> +
+
Supports PNG, JPG, JPEG, WEBP and GIF
+ + : + } +
+
+ ) +} + +export default Uploader diff --git a/web/app/components/base/app-icon-picker/hooks.tsx b/web/app/components/base/app-icon-picker/hooks.tsx new file mode 100644 index 0000000000..b3f67c0dca --- /dev/null +++ b/web/app/components/base/app-icon-picker/hooks.tsx @@ -0,0 +1,43 @@ +import { useCallback, useState } from 'react' + +export const useDraggableUploader = (setImageFn: (file: File) => void) => { + const [isDragActive, setIsDragActive] = useState(false) + + const handleDragEnter = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(true) + }, []) + + const handleDragOver = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + }, []) + + const handleDragLeave = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(false) + }, []) + + const handleDrop = useCallback((e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragActive(false) + + const file = e.dataTransfer.files[0] + + if (!file) + return + + setImageFn(file) + }, [setImageFn]) + + return { + handleDragEnter, + handleDragOver, + handleDragLeave, + handleDrop, + isDragActive, + } +} diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx new file mode 100644 index 0000000000..8254815475 --- /dev/null +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -0,0 +1,139 @@ +import type { FC } from 'react' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import type { Area } from 'react-easy-crop' +import Modal from '../modal' +import Divider from '../divider' +import Button from '../button' +import { ImagePlus } from '../icons/src/vender/line/images' +import { useLocalFileUploader } from '../image-uploader/hooks' +import EmojiPickerInner from '../emoji-picker/Inner' +import Uploader from './Uploader' +import s from './style.module.css' +import 
getCroppedImg from './utils' +import type { AppIconType, ImageFile } from '@/types/app' +import cn from '@/utils/classnames' +import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' +export type AppIconEmojiSelection = { + type: 'emoji' + icon: string + background: string +} + +export type AppIconImageSelection = { + type: 'image' + fileId: string + url: string +} + +export type AppIconSelection = AppIconEmojiSelection | AppIconImageSelection + +type AppIconPickerProps = { + onSelect?: (payload: AppIconSelection) => void + onClose?: () => void + className?: string +} + +const AppIconPicker: FC = ({ + onSelect, + onClose, + className, +}) => { + const { t } = useTranslation() + + const tabs = [ + { key: 'emoji', label: t('app.iconPicker.emoji'), icon: 🤖 }, + { key: 'image', label: t('app.iconPicker.image'), icon: }, + ] + const [activeTab, setActiveTab] = useState('emoji') + + const [emoji, setEmoji] = useState<{ emoji: string; background: string }>() + const handleSelectEmoji = useCallback((emoji: string, background: string) => { + setEmoji({ emoji, background }) + }, [setEmoji]) + + const [uploading, setUploading] = useState() + + const { handleLocalFileUpload } = useLocalFileUploader({ + limit: 3, + disabled: false, + onUpload: (imageFile: ImageFile) => { + if (imageFile.fileId) { + setUploading(false) + onSelect?.({ + type: 'image', + fileId: imageFile.fileId, + url: imageFile.url, + }) + } + }, + }) + + const [imageCropInfo, setImageCropInfo] = useState<{ tempUrl: string; croppedAreaPixels: Area; fileName: string }>() + const handleImageCropped = async (tempUrl: string, croppedAreaPixels: Area, fileName: string) => { + setImageCropInfo({ tempUrl, croppedAreaPixels, fileName }) + } + + const handleSelect = async () => { + if (activeTab === 'emoji') { + if (emoji) { + onSelect?.({ + type: 'emoji', + icon: emoji.emoji, + background: emoji.background, + }) + } + } + else { + if (!imageCropInfo) + return + setUploading(true) + const blob = await 
getCroppedImg(imageCropInfo.tempUrl, imageCropInfo.croppedAreaPixels) + const file = new File([blob], imageCropInfo.fileName, { type: blob.type }) + handleLocalFileUpload(file) + } + } + + return { }} + isShow + closable={false} + wrapperClassName={className} + className={cn(s.container, '!w-[362px] !p-0')} + > + {!DISABLE_UPLOAD_IMAGE_AS_ICON &&
+
+ {tabs.map(tab => ( + + ))} +
+
} + + + + + + + +
+ + + +
+
+} + +export default AppIconPicker diff --git a/web/app/components/base/app-icon-picker/style.module.css b/web/app/components/base/app-icon-picker/style.module.css new file mode 100644 index 0000000000..5facb3560a --- /dev/null +++ b/web/app/components/base/app-icon-picker/style.module.css @@ -0,0 +1,12 @@ +.container { + display: flex; + flex-direction: column; + align-items: flex-start; + width: 362px; + max-height: 552px; + + border: 0.5px solid #EAECF0; + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); + border-radius: 12px; + background: #fff; +} diff --git a/web/app/components/base/app-icon-picker/utils.ts b/web/app/components/base/app-icon-picker/utils.ts new file mode 100644 index 0000000000..0c90e96feb --- /dev/null +++ b/web/app/components/base/app-icon-picker/utils.ts @@ -0,0 +1,98 @@ +export const createImage = (url: string) => + new Promise((resolve, reject) => { + const image = new Image() + image.addEventListener('load', () => resolve(image)) + image.addEventListener('error', error => reject(error)) + image.setAttribute('crossOrigin', 'anonymous') // needed to avoid cross-origin issues on CodeSandbox + image.src = url + }) + +export function getRadianAngle(degreeValue: number) { + return (degreeValue * Math.PI) / 180 +} + +/** + * Returns the new bounding area of a rotated rectangle. 
+ */ +export function rotateSize(width: number, height: number, rotation: number) { + const rotRad = getRadianAngle(rotation) + + return { + width: + Math.abs(Math.cos(rotRad) * width) + Math.abs(Math.sin(rotRad) * height), + height: + Math.abs(Math.sin(rotRad) * width) + Math.abs(Math.cos(rotRad) * height), + } +} + +/** + * This function was adapted from the one in the ReadMe of https://github.com/DominicTobias/react-image-crop + */ +export default async function getCroppedImg( + imageSrc: string, + pixelCrop: { x: number; y: number; width: number; height: number }, + rotation = 0, + flip = { horizontal: false, vertical: false }, +): Promise { + const image = await createImage(imageSrc) + const canvas = document.createElement('canvas') + const ctx = canvas.getContext('2d') + + if (!ctx) + throw new Error('Could not create a canvas context') + + const rotRad = getRadianAngle(rotation) + + // calculate bounding box of the rotated image + const { width: bBoxWidth, height: bBoxHeight } = rotateSize( + image.width, + image.height, + rotation, + ) + + // set canvas size to match the bounding box + canvas.width = bBoxWidth + canvas.height = bBoxHeight + + // translate canvas context to a central location to allow rotating and flipping around the center + ctx.translate(bBoxWidth / 2, bBoxHeight / 2) + ctx.rotate(rotRad) + ctx.scale(flip.horizontal ? -1 : 1, flip.vertical ? 
-1 : 1) + ctx.translate(-image.width / 2, -image.height / 2) + + // draw rotated image + ctx.drawImage(image, 0, 0) + + const croppedCanvas = document.createElement('canvas') + + const croppedCtx = croppedCanvas.getContext('2d') + + if (!croppedCtx) + throw new Error('Could not create a canvas context') + + // Set the size of the cropped canvas + croppedCanvas.width = pixelCrop.width + croppedCanvas.height = pixelCrop.height + + // Draw the cropped image onto the new canvas + croppedCtx.drawImage( + canvas, + pixelCrop.x, + pixelCrop.y, + pixelCrop.width, + pixelCrop.height, + 0, + 0, + pixelCrop.width, + pixelCrop.height, + ) + + return new Promise((resolve, reject) => { + croppedCanvas.toBlob((file) => { + if (file) + resolve(file) + else + reject(new Error('Could not create a blob')) + }, 'image/jpeg') + }) +} diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx index 9d8cf28bed..5e7378c087 100644 --- a/web/app/components/base/app-icon/index.tsx +++ b/web/app/components/base/app-icon/index.tsx @@ -1,17 +1,21 @@ -import type { FC } from 'react' +'use client' -import data from '@emoji-mart/data' +import type { FC } from 'react' import { init } from 'emoji-mart' +import data from '@emoji-mart/data' import style from './style.module.css' import classNames from '@/utils/classnames' +import type { AppIconType } from '@/types/app' init({ data }) export type AppIconProps = { size?: 'xs' | 'tiny' | 'small' | 'medium' | 'large' rounded?: boolean + iconType?: AppIconType | null icon?: string - background?: string + background?: string | null + imageUrl?: string | null className?: string innerIcon?: React.ReactNode onClick?: () => void @@ -20,28 +24,34 @@ export type AppIconProps = { const AppIcon: FC = ({ size = 'medium', rounded = false, + iconType, icon, background, + imageUrl, className, innerIcon, onClick, }) => { - return ( - - {innerIcon || ((icon && icon !== '') ? 
: )} - + const wrapperClassName = classNames( + style.appIcon, + size !== 'medium' && style[size], + rounded && style.rounded, + className ?? '', + 'overflow-hidden', ) + + const isValidImageIcon = iconType === 'image' && imageUrl + + return + {isValidImageIcon + ? app icon + : (innerIcon || ((icon && icon !== '') ? : )) + } + } export default AppIcon diff --git a/web/app/components/base/app-icon/style.module.css b/web/app/components/base/app-icon/style.module.css index 20f76d4099..06a2478d41 100644 --- a/web/app/components/base/app-icon/style.module.css +++ b/web/app/components/base/app-icon/style.module.css @@ -1,5 +1,5 @@ .appIcon { - @apply flex items-center justify-center relative w-9 h-9 text-lg bg-teal-100 rounded-lg grow-0 shrink-0; + @apply flex items-center justify-center relative w-9 h-9 text-lg rounded-lg grow-0 shrink-0; } .appIcon.large { diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 3fa301d268..624cc53a18 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -43,7 +43,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR(installedAppInfo ? 
null : 'appInfo', fetchAppInfo) - useAppFavicon(!installedAppInfo, appInfo?.site.icon, appInfo?.site.icon_background) + useAppFavicon({ + enable: !installedAppInfo, + icon_type: appInfo?.site.icon_type, + icon: appInfo?.site.icon, + icon_background: appInfo?.site.icon_background, + icon_url: appInfo?.site.icon_url, + }) const appData = useMemo(() => { if (isInstalledApp) { @@ -52,8 +58,10 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { app_id: id, site: { title: app.name, + icon_type: app.icon_type, icon: app.icon, icon_background: app.icon_background, + icon_url: app.icon_url, prompt_public: false, copyright: '', show_workflow_steps: true, diff --git a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx index dd145c36b9..17a2751f11 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx @@ -67,8 +67,10 @@ const Sidebar = () => {
{appData?.site.title} diff --git a/web/app/components/base/chat/chat/chat-input.tsx b/web/app/components/base/chat/chat/chat-input.tsx index 6d15010f85..0c083157a1 100644 --- a/web/app/components/base/chat/chat/chat-input.tsx +++ b/web/app/components/base/chat/chat/chat-input.tsx @@ -49,6 +49,7 @@ const ChatInput: FC = ({ const { t } = useTranslation() const { notify } = useContext(ToastContext) const [voiceInputShow, setVoiceInputShow] = useState(false) + const textAreaRef = useRef(null) const { files, onUpload, @@ -89,7 +90,7 @@ const ChatInput: FC = ({ } const handleKeyUp = (e: React.KeyboardEvent) => { - if (e.code === 'Enter') { + if (e.key === 'Enter') { e.preventDefault() // prevent send message when using input method enter if (!e.shiftKey && !isUseInputMethod.current) @@ -99,7 +100,7 @@ const ChatInput: FC = ({ const handleKeyDown = (e: React.KeyboardEvent) => { isUseInputMethod.current = e.nativeEvent.isComposing - if (e.code === 'Enter' && !e.shiftKey) { + if (e.key === 'Enter' && !e.shiftKey) { setQuery(query.replace(/\n$/, '')) e.preventDefault() } @@ -176,6 +177,7 @@ const ChatInput: FC = ({ ) }