diff --git a/api/commands.py b/api/commands.py index 76c8d3e382..1b1abcd9c3 100644 --- a/api/commands.py +++ b/api/commands.py @@ -677,3 +677,42 @@ def extract_plugins(output_file: str, workers: int): PluginMigration.extract_plugins(output_file, workers) click.echo(click.style("Extract plugins completed.", fg="green")) + + +@click.command("extract-unique-identifiers", help="Extract unique identifiers.") +@click.option( + "--output_file", + prompt=True, + help="The file to store the extracted unique identifiers.", + default="unique_identifiers.json", +) +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +def extract_unique_plugins(output_file: str, input_file: str): + """ + Extract unique plugins. + """ + click.echo(click.style("Starting extract unique plugins.", fg="white")) + + PluginMigration.extract_unique_plugins(input_file, output_file) + + click.echo(click.style("Extract unique plugins completed.", fg="green")) + + +@click.command("install-plugins", help="Install plugins.") +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +@click.option( + "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" +) +def install_plugins(input_file: str, output_file: str): + """ + Install plugins. + """ + click.echo(click.style("Starting install plugins.", fg="white")) + + PluginMigration.install_plugins(input_file, output_file) + + click.echo(click.style("Install plugins completed.", fg="green")) diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index c793346829..85321bed94 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from enum import Enum from typing import Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -134,6 +134,14 @@ class ProviderEntity(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + @field_validator("models", mode="before") + @classmethod + def validate_models(cls, v): + # returns EmptyList if v is empty + if not v: + return [] + return v + def to_simple_provider(self) -> SimpleProviderEntity: """ Convert to simple provider. diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/manager/plugin.py index c96e6c621b..4f5970d3b9 100644 --- a/api/core/plugin/manager/plugin.py +++ b/api/core/plugin/manager/plugin.py @@ -76,7 +76,11 @@ class PluginInstallationManager(BasePluginManager): ) def install_from_identifiers( - self, tenant_id: str, identifiers: Sequence[str], source: PluginInstallationSource, meta: dict + self, + tenant_id: str, + identifiers: Sequence[str], + source: PluginInstallationSource, + metas: list[dict], ) -> PluginInstallTaskStartResponse: """ Install a plugin from an identifier. @@ -89,7 +93,7 @@ class PluginInstallationManager(BasePluginManager): data={ "plugin_unique_identifiers": identifiers, "source": source, - "meta": meta, + "metas": metas, }, headers={"Content-Type": "application/json"}, ) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 353f9144b8..0418361094 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -7,6 +7,7 @@ def init_app(app: DifyApp): convert_to_agent_apps, create_tenant, extract_plugins, + extract_unique_plugins, fix_app_site_missing, migrate_data_for_plugin, reset_email, @@ -14,6 +15,7 @@ def init_app(app: DifyApp): reset_password, upgrade_db, vdb_migrate, + install_plugins, ) cmds_to_register = [ @@ -28,6 +30,8 @@ def init_app(app: DifyApp): fix_app_site_missing, migrate_data_for_plugin, extract_plugins, + extract_unique_plugins, + install_plugins, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 6c3e1ef9f6..f7f67d8ebb 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -1,14 +1,25 @@ +from concurrent.futures import ThreadPoolExecutor import datetime import json import logging from collections.abc import Sequence +from pathlib import Path +import sys +import time +from typing import Any, Mapping, Optional +from uuid import uuid4 import click +import tqdm from flask import Flask, current_app from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.entities import DEFAULT_PLUGIN_ID +from core.helper import marketplace +from core.plugin.entities.plugin import PluginInstallationSource +from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus +from core.plugin.manager.plugin import PluginInstallationManager from core.tools.entities.tool_entities import ToolProviderType from models.account import Tenant from models.engine import db @@ -199,7 +210,7 @@ class PluginMigration: if provider_name == "google": provider_name = "gemini" - result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name) + result.append(DEFAULT_PLUGIN_ID + "/" + provider_name) elif provider_name: result.append(provider_name) @@ -215,7 +226,7 @@ class PluginMigration: result = [] for row in rs: if "/" not in row.provider: - result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider) + result.append(DEFAULT_PLUGIN_ID + "/" + row.provider) else: result.append(row.provider) @@ -234,7 +245,7 @@ class PluginMigration: provider_name = "stepfun_tool" if "/" not in provider_name: - return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name + return DEFAULT_PLUGIN_ID + "/" + provider_name else: return provider_name @@ -297,3 +308,216 @@ class PluginMigration: continue return result + + @classmethod + def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]: + """ + Fetch plugin unique identifier using plugin id. + """ + plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id]) + if not plugin_manifest: + return None + + return plugin_manifest[0].latest_package_identifier + + @classmethod + def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None: + """ + Extract unique plugins. + """ + Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins))) + + @classmethod + def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]: + plugins: dict[str, str] = {} + plugin_ids = [] + plugin_not_exist = [] + logger.info(f"Extracting unique plugins from {extracted_plugins}") + with open(extracted_plugins) as f: + for line in f: + data = json.loads(line) + new_plugin_ids = data.get("plugins", []) + for plugin_id in new_plugin_ids: + if plugin_id not in plugin_ids: + plugin_ids.append(plugin_id) + + def fetch_plugin(plugin_id): + unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id) + if unique_identifier: + plugins[plugin_id] = unique_identifier + else: + plugin_not_exist.append(plugin_id) + + with ThreadPoolExecutor(max_workers=10) as executor: + list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids))) + + return {"plugins": plugins, "plugin_not_exist": plugin_not_exist} + + @classmethod + def install_plugins(cls, extracted_plugins: str, output_file: str) -> None: + """ + Install plugins. + """ + manager = PluginInstallationManager() + + plugins = cls.extract_unique_plugins(extracted_plugins) + not_installed = [] + plugin_install_failed = [] + + # use a fake tenant id to install all the plugins + fake_tenant_id = uuid4().hex + logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}") + + thread_pool = ThreadPoolExecutor(max_workers=40) + + response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"]) + if response.get("failed"): + plugin_install_failed.extend(response.get("failed", [])) + + def install(tenant_id: str, plugin_ids: list[str]) -> None: + logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}") + # at most 64 plugins one batch + for i in range(0, len(plugin_ids), 64): + batch_plugin_ids = plugin_ids[i : i + 64] + batch_plugin_identifiers = [plugins["plugins"][plugin_id] for plugin_id in batch_plugin_ids] + manager.install_from_identifiers( + tenant_id, + batch_plugin_identifiers, + PluginInstallationSource.Marketplace, + metas=[ + { + "plugin_unique_identifier": identifier, + } + for identifier in batch_plugin_identifiers + ], + ) + + with open(extracted_plugins, "r") as f: + """ + Read line by line, and install plugins for each tenant. + """ + for line in f: + data = json.loads(line) + tenant_id = data.get("tenant_id") + plugin_ids = data.get("plugins", []) + current_not_installed = { + "tenant_id": tenant_id, + "plugin_not_exist": [], + } + # get plugin unique identifier + for plugin_id in plugin_ids: + unique_identifier = plugins.get(plugin_id) + if unique_identifier: + current_not_installed["plugin_not_exist"].append(plugin_id) + + if current_not_installed["plugin_not_exist"]: + not_installed.append(current_not_installed) + + thread_pool.submit(install, tenant_id, plugin_ids) + + thread_pool.shutdown(wait=True) + + logger.info("Uninstall plugins") + + sys.exit(-1) + + # get installation + try: + installation = manager.list_plugins(fake_tenant_id) + while installation: + for plugin in installation: + manager.uninstall(fake_tenant_id, plugin.installation_id) + + installation = manager.list_plugins(fake_tenant_id) + except Exception: + logger.exception(f"Failed to get installation for tenant {fake_tenant_id}") + + Path(output_file).write_text( + json.dumps( + { + "not_installed": not_installed, + "plugin_install_failed": plugin_install_failed, + } + ) + ) + + @classmethod + def handle_plugin_instance_install( + cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str] + ) -> Mapping[str, Any]: + """ + Install plugins for a tenant. + """ + manager = PluginInstallationManager() + + # download all the plugins and upload + thread_pool = ThreadPoolExecutor(max_workers=10) + futures = [] + + for plugin_id, plugin_identifier in plugin_identifiers_map.items(): + + def download_and_upload(tenant_id, plugin_id, plugin_identifier): + plugin_package = marketplace.download_plugin_pkg(plugin_identifier) + if not plugin_package: + raise Exception(f"Failed to download plugin {plugin_identifier}") + + # upload + manager.upload_pkg(tenant_id, plugin_package, verify_signature=True) + + futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier)) + + # Wait for all downloads to complete + for future in futures: + future.result() # This will raise any exceptions that occurred + + thread_pool.shutdown(wait=True) + success = [] + failed = [] + + reverse_map = {v: k for k, v in plugin_identifiers_map.items()} + + # at most 64 plugins one batch + for i in range(0, len(plugin_identifiers_map), 64): + batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 64] + batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids] + + try: + response = manager.install_from_identifiers( + tenant_id=tenant_id, + identifiers=batch_plugin_identifiers, + source=PluginInstallationSource.Marketplace, + metas=[ + { + "plugin_unique_identifier": identifier, + } + for identifier in batch_plugin_identifiers + ], + ) + except Exception: + # add to failed + failed.extend(batch_plugin_identifiers) + continue + + if response.all_installed: + success.extend(batch_plugin_identifiers) + continue + + task_id = response.task_id + done = False + while not done: + status = manager.fetch_plugin_installation_task(tenant_id, task_id) + if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]: + for plugin in status.plugins: + if plugin.status == PluginInstallTaskStatus.Success: + success.append(reverse_map[plugin.plugin_unique_identifier]) + else: + failed.append(reverse_map[plugin.plugin_unique_identifier]) + logger.error( + f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}" + ) + + done = True + else: + time.sleep(1) + + return {"success": success, "failed": failed} diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index e3c37ecba7..f84baf6b81 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -232,7 +232,7 @@ class PluginService: tenant_id, plugin_unique_identifiers, PluginInstallationSource.Package, - {}, + [{}], ) @staticmethod @@ -246,11 +246,13 @@ class PluginService: tenant_id, [plugin_unique_identifier], PluginInstallationSource.Github, - { - "repo": repo, - "version": version, - "package": package, - }, + [ + { + "repo": repo, + "version": version, + "package": package, + } + ], ) @staticmethod @@ -277,9 +279,12 @@ class PluginService: tenant_id, plugin_unique_identifiers, PluginInstallationSource.Marketplace, - { - "plugin_unique_identifier": plugin_unique_identifier, - }, + [ + { + "plugin_unique_identifier": plugin_unique_identifier, + } + for plugin_unique_identifier in plugin_unique_identifiers + ], ) @staticmethod