feat: plugin migrations

This commit is contained in:
Yeuoly 2024-12-31 16:38:02 +08:00
parent 06412b37d3
commit 6e73ad2fc6
6 changed files with 299 additions and 15 deletions

View File

@ -677,3 +677,42 @@ def extract_plugins(output_file: str, workers: int):
PluginMigration.extract_plugins(output_file, workers)
click.echo(click.style("Extract plugins completed.", fg="green"))
@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
@click.option(
    "--output_file",
    prompt=True,
    help="The file to store the extracted unique identifiers.",
    default="unique_identifiers.json",
)
@click.option(
    "--input_file",
    prompt=True,
    help="The file containing the extracted plugins (one JSON object per line).",
    default="plugins.jsonl",
)
def extract_unique_plugins(output_file: str, input_file: str):
    """
    Extract unique plugins.

    Reads the extracted plugins from ``input_file`` and writes the resolved
    unique identifiers to ``output_file``.
    """
    click.echo(click.style("Starting extract unique plugins.", fg="white"))

    # PluginMigration.extract_unique_plugins() accepts a single path and only returns
    # the mapping; the *_to_file variant is the one that persists it to output_file.
    PluginMigration.extract_unique_plugins_to_file(input_file, output_file)

    click.echo(click.style("Extract unique plugins completed.", fg="green"))
@click.command("install-plugins", help="Install plugins.")
@click.option(
    "--input_file",
    prompt=True,
    help="The file containing the extracted plugins (one JSON object per line).",
    default="plugins.jsonl",
)
@click.option(
    "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
def install_plugins(input_file: str, output_file: str):
    """
    Install plugins.

    Delegates to ``PluginMigration.install_plugins`` with ``input_file`` as the
    extracted-plugins file and ``output_file`` as the report destination.
    """
    click.echo(click.style("Starting install plugins.", fg="white"))

    PluginMigration.install_plugins(input_file, output_file)

    click.echo(click.style("Install plugins completed.", fg="green"))

View File

@ -2,7 +2,7 @@ from collections.abc import Sequence
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -134,6 +134,14 @@ class ProviderEntity(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@field_validator("models", mode="before")
@classmethod
def validate_models(cls, v):
    """Replace any falsy ``models`` value with an empty list; pass other values through unchanged."""
    return v or []
def to_simple_provider(self) -> SimpleProviderEntity:
"""
Convert to simple provider.

View File

@ -76,7 +76,11 @@ class PluginInstallationManager(BasePluginManager):
)
def install_from_identifiers(
self, tenant_id: str, identifiers: Sequence[str], source: PluginInstallationSource, meta: dict
self,
tenant_id: str,
identifiers: Sequence[str],
source: PluginInstallationSource,
metas: list[dict],
) -> PluginInstallTaskStartResponse:
"""
Install a plugin from an identifier.
@ -89,7 +93,7 @@ class PluginInstallationManager(BasePluginManager):
data={
"plugin_unique_identifiers": identifiers,
"source": source,
"meta": meta,
"metas": metas,
},
headers={"Content-Type": "application/json"},
)

View File

@ -7,6 +7,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
extract_plugins,
extract_unique_plugins,
fix_app_site_missing,
migrate_data_for_plugin,
reset_email,
@ -14,6 +15,7 @@ def init_app(app: DifyApp):
reset_password,
upgrade_db,
vdb_migrate,
install_plugins,
)
cmds_to_register = [
@ -28,6 +30,8 @@ def init_app(app: DifyApp):
fix_app_site_missing,
migrate_data_for_plugin,
extract_plugins,
extract_unique_plugins,
install_plugins,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -1,14 +1,25 @@
from concurrent.futures import ThreadPoolExecutor
import datetime
import json
import logging
from collections.abc import Sequence
from pathlib import Path
import sys
import time
from typing import Any, Mapping, Optional
from uuid import uuid4
import click
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
@ -199,7 +210,7 @@ class PluginMigration:
if provider_name == "google":
provider_name = "gemini"
result.append(DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name)
result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)
elif provider_name:
result.append(provider_name)
@ -215,7 +226,7 @@ class PluginMigration:
result = []
for row in rs:
if "/" not in row.provider:
result.append(DEFAULT_PLUGIN_ID + "/" + row.provider + "/" + row.provider)
result.append(DEFAULT_PLUGIN_ID + "/" + row.provider)
else:
result.append(row.provider)
@ -234,7 +245,7 @@ class PluginMigration:
provider_name = "stepfun_tool"
if "/" not in provider_name:
return DEFAULT_PLUGIN_ID + "/" + provider_name + "/" + provider_name
return DEFAULT_PLUGIN_ID + "/" + provider_name
else:
return provider_name
@ -297,3 +308,216 @@ class PluginMigration:
continue
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
    """
    Look up the latest package identifier for ``plugin_id`` in the marketplace.

    Returns ``None`` when the marketplace returns no manifest for the plugin.
    """
    manifests = marketplace.batch_fetch_plugin_manifests([plugin_id])
    return manifests[0].latest_package_identifier if manifests else None
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
    """
    Resolve the unique plugins found in ``extracted_plugins`` and dump them as JSON to ``output_file``.
    """
    unique_plugins = cls.extract_unique_plugins(extracted_plugins)
    serialized = json.dumps(unique_plugins)
    Path(output_file).write_text(serialized)
@classmethod
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
    """
    Collect the distinct plugin ids from a JSONL file and resolve their unique identifiers.

    :param extracted_plugins: path to a file with one JSON object per line; the
        ``plugins`` key of each object is read as a list of plugin ids.
    :return: ``{"plugins": {plugin_id: unique_identifier}, "plugin_not_exist": [plugin_id, ...]}``,
        both in first-seen order of the plugin ids.
    """
    logger.info(f"Extracting unique plugins from {extracted_plugins}")

    # A dict gives O(1) de-duplication while keeping first-seen order.
    seen: dict[str, None] = {}
    with open(extracted_plugins) as f:
        for line in f:
            # tolerate blank lines instead of failing in json.loads
            if not line.strip():
                continue
            data = json.loads(line)
            # `or []` also covers an explicit null value for "plugins"
            for plugin_id in data.get("plugins") or []:
                seen.setdefault(plugin_id, None)
    plugin_ids = list(seen)

    # Only the lookups run in worker threads; results are assembled here so that no
    # container is mutated concurrently and the output order is deterministic.
    with ThreadPoolExecutor(max_workers=10) as executor:
        unique_identifiers = list(
            tqdm.tqdm(executor.map(cls._fetch_plugin_unique_identifier, plugin_ids), total=len(plugin_ids))
        )

    plugins: dict[str, str] = {}
    plugin_not_exist = []
    for plugin_id, unique_identifier in zip(plugin_ids, unique_identifiers):
        if unique_identifier:
            plugins[plugin_id] = unique_identifier
        else:
            plugin_not_exist.append(plugin_id)

    return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str) -> None:
    """
    Install plugins.

    Steps:
      1. Resolve the unique identifier of every plugin listed in ``extracted_plugins``.
      2. Install all resolved plugins once under a throwaway tenant id.
      3. For every line (tenant) of ``extracted_plugins``, install that tenant's
         resolved plugins in batches of at most 64.
      4. Uninstall everything from the throwaway tenant.
      5. Write ``{"not_installed": [...], "plugin_install_failed": [...]}`` as JSON
         to ``output_file``.

    :param extracted_plugins: path to a JSONL file; each line is read for the
        ``tenant_id`` and ``plugins`` keys.
    :param output_file: path the JSON report is written to.
    """
    manager = PluginInstallationManager()

    plugins = cls.extract_unique_plugins(extracted_plugins)
    # plugin id -> unique identifier, for every plugin that could be resolved
    identifiers_by_plugin_id = plugins["plugins"]
    not_installed = []
    plugin_install_failed = []

    # use a fake tenant id to install all the plugins
    fake_tenant_id = uuid4().hex
    logger.info(f"Installing {len(identifiers_by_plugin_id)} plugin instances for fake tenant {fake_tenant_id}")

    response = cls.handle_plugin_instance_install(fake_tenant_id, identifiers_by_plugin_id)
    if response.get("failed"):
        plugin_install_failed.extend(response.get("failed", []))

    def install(tenant_id: str, plugin_ids: list[str]) -> None:
        logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
        # at most 64 plugins one batch
        for i in range(0, len(plugin_ids), 64):
            batch_plugin_identifiers = [
                identifiers_by_plugin_id[plugin_id] for plugin_id in plugin_ids[i : i + 64]
            ]
            try:
                manager.install_from_identifiers(
                    tenant_id,
                    batch_plugin_identifiers,
                    PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )
            except Exception:
                # The future's result is never inspected, so an exception raised here
                # would otherwise vanish without a trace.
                logger.exception(f"Failed to install plugins for tenant {tenant_id}")

    # Leaving the `with` block closes the file and then waits for all submitted installs.
    with ThreadPoolExecutor(max_workers=40) as thread_pool, open(extracted_plugins, "r") as f:
        # Read line by line, and install plugins for each tenant.
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            tenant_id = data.get("tenant_id")
            plugin_ids = data.get("plugins") or []

            # Split the tenant's plugins into those with a resolved unique identifier
            # and those that could not be resolved.
            installable = []
            missing = []
            for plugin_id in plugin_ids:
                if plugin_id in identifiers_by_plugin_id:
                    installable.append(plugin_id)
                else:
                    missing.append(plugin_id)

            if missing:
                not_installed.append(
                    {
                        "tenant_id": tenant_id,
                        "plugin_not_exist": missing,
                    }
                )

            # Only resolved plugins are submitted; unresolved ids have no identifier to install.
            if installable:
                thread_pool.submit(install, tenant_id, installable)

    logger.info("Uninstall plugins")

    # get installation
    try:
        installation = manager.list_plugins(fake_tenant_id)
        while installation:
            for plugin in installation:
                manager.uninstall(fake_tenant_id, plugin.installation_id)

            installation = manager.list_plugins(fake_tenant_id)
    except Exception:
        logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")

    Path(output_file).write_text(
        json.dumps(
            {
                "not_installed": not_installed,
                "plugin_install_failed": plugin_install_failed,
            }
        )
    )
@classmethod
def handle_plugin_instance_install(
    cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
) -> Mapping[str, Any]:
    """
    Install plugins for a tenant.

    Every package is first downloaded from the marketplace and uploaded for the
    tenant, then the plugins are installed in batches of at most 64. A batch that
    is not reported as already installed is polled once per second until its task
    status is ``Failed`` or ``Success``.

    :param tenant_id: tenant the plugins are uploaded and installed for.
    :param plugin_identifiers_map: mapping of plugin id -> plugin unique identifier.
    :return: ``{"success": [...], "failed": [...]}``; both lists hold plugin ids.
    :raises Exception: if any package cannot be downloaded or uploaded.
    """
    manager = PluginInstallationManager()

    def download_and_upload(plugin_identifier: str) -> None:
        plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
        if not plugin_package:
            raise Exception(f"Failed to download plugin {plugin_identifier}")

        # upload
        manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)

    # download all the plugins and upload; the context manager shuts the pool down
    # even when one of the downloads raises
    with ThreadPoolExecutor(max_workers=10) as thread_pool:
        futures = [
            thread_pool.submit(download_and_upload, plugin_identifier)
            for plugin_identifier in plugin_identifiers_map.values()
        ]

        # Wait for all downloads to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred

    success = []
    failed = []

    reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
    plugin_ids = list(plugin_identifiers_map)

    # at most 64 plugins one batch
    for i in range(0, len(plugin_ids), 64):
        batch_plugin_ids = plugin_ids[i : i + 64]
        batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]

        try:
            response = manager.install_from_identifiers(
                tenant_id=tenant_id,
                identifiers=batch_plugin_identifiers,
                source=PluginInstallationSource.Marketplace,
                metas=[
                    {
                        "plugin_unique_identifier": identifier,
                    }
                    for identifier in batch_plugin_identifiers
                ],
            )
        except Exception:
            # add to failed; record plugin ids, consistent with the polled path below
            failed.extend(batch_plugin_ids)
            continue

        if response.all_installed:
            success.extend(batch_plugin_ids)
            continue

        task_id = response.task_id
        done = False
        # NOTE(review): this loop has no timeout; it relies on the task eventually
        # reaching Failed or Success.
        while not done:
            status = manager.fetch_plugin_installation_task(tenant_id, task_id)
            if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
                for plugin in status.plugins:
                    identifier = plugin.plugin_unique_identifier
                    # fall back to the raw identifier rather than raising KeyError
                    # if the task reports an identifier that was not requested
                    plugin_id = reverse_map.get(identifier, identifier)
                    if plugin.status == PluginInstallTaskStatus.Success:
                        success.append(plugin_id)
                    else:
                        failed.append(plugin_id)
                        logger.error(f"Failed to install plugin {identifier}, error: {plugin.message}")

                done = True
            else:
                time.sleep(1)

    return {"success": success, "failed": failed}

View File

@ -232,7 +232,7 @@ class PluginService:
tenant_id,
plugin_unique_identifiers,
PluginInstallationSource.Package,
{},
[{}],
)
@staticmethod
@ -246,11 +246,13 @@ class PluginService:
tenant_id,
[plugin_unique_identifier],
PluginInstallationSource.Github,
{
"repo": repo,
"version": version,
"package": package,
},
[
{
"repo": repo,
"version": version,
"package": package,
}
],
)
@staticmethod
@ -277,9 +279,12 @@ class PluginService:
tenant_id,
plugin_unique_identifiers,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": plugin_unique_identifier,
},
[
{
"plugin_unique_identifier": plugin_unique_identifier,
}
for plugin_unique_identifier in plugin_unique_identifiers
],
)
@staticmethod