dify/api/services/plugin/data_migration.py

204 lines
8.7 KiB
Python

import json
import logging
import click
from core.entities import DEFAULT_PLUGIN_ID
from models.engine import db
logger = logging.getLogger(__name__)
class PluginDataMigration:
    """Rewrite bare provider names into plugin-qualified provider IDs.

    Every provider column handled here that still holds a value without a
    ``/`` (for example ``openai``) is rewritten to
    ``{DEFAULT_PLUGIN_ID}/{name}/{name}``. Values that already contain a ``/``
    are excluded by the SELECT filters, so re-running the migration does not
    touch rows that were already converted.
    """

    @classmethod
    def migrate(cls) -> None:
        """Run the provider-name migration over every affected table, in order."""
        cls.migrate_db_records("providers", "provider_name")  # large table
        cls.migrate_db_records("provider_models", "provider_name")
        cls.migrate_db_records("provider_orders", "provider_name")
        cls.migrate_db_records("tenant_default_models", "provider_name")
        cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
        cls.migrate_db_records("provider_model_settings", "provider_name")
        cls.migrate_db_records("load_balancing_model_configs", "provider_name")
        cls.migrate_datasets()
        cls.migrate_db_records("embeddings", "provider_name")  # large table
        cls.migrate_db_records("dataset_collection_bindings", "provider_name")
        cls.migrate_db_records("tool_builtin_providers", "provider")

    @staticmethod
    def _qualify_provider_name(provider_name: str, plugin_id: str) -> str:
        """Return ``plugin_id/provider_name/provider_name`` for a bare provider name."""
        return f"{plugin_id}/{provider_name}/{provider_name}"

    @staticmethod
    def _decode_retrieval_model(raw: object) -> "dict | None":
        """Normalize a ``retrieval_model`` column value to a dict, or ``None``.

        The value is used as-is when the driver already returned a dict. A
        string is parsed as JSON (NOTE(review): only relevant if the column or
        driver hands JSON back as text - confirm against the schema). Anything
        else, including malformed JSON or a non-object document, yields ``None``
        so the caller simply leaves ``retrieval_model`` untouched.
        """
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, str):
            try:
                decoded = json.loads(raw)
            except ValueError:
                return None
            return decoded if isinstance(decoded, dict) else None
        return None

    @staticmethod
    def _upgrade_reranking_provider(retrieval_model: "dict | None", plugin_id: str) -> "str | None":
        """Qualify ``reranking_model.reranking_provider_name`` in place.

        Args:
            retrieval_model: Decoded ``retrieval_model`` document, or ``None``.
            plugin_id: Plugin ID used as the prefix of the qualified name.

        Returns:
            The previous (bare) provider name when the document was modified,
            or ``None`` when nothing had to change (missing/empty name, name
            already containing ``/``, or an unexpected document shape).
        """
        if not isinstance(retrieval_model, dict):
            return None
        reranking_model = retrieval_model.get("reranking_model")
        if not isinstance(reranking_model, dict):
            return None
        current_name = reranking_model.get("reranking_provider_name")
        if not isinstance(current_name, str) or not current_name or "/" in current_name:
            return None
        reranking_model["reranking_provider_name"] = f"{plugin_id}/{current_name}/{current_name}"
        return current_name

    @classmethod
    def migrate_datasets(cls) -> None:
        """Migrate the ``datasets`` table.

        For each dataset whose ``embedding_model_provider`` is a bare name, the
        column is rewritten to the plugin-qualified form. If the row's
        ``retrieval_model`` JSON carries a bare ``reranking_provider_name``, it
        is qualified in the same UPDATE.

        Rows are fetched 1000 at a time, one transaction per batch. Each row is
        updated inside its own SAVEPOINT, so a failing row is rolled back,
        logged and remembered in ``failed_ids`` without affecting the other
        rows of the batch. The loop ends when a batch migrates no row.
        """
        table_name = "datasets"
        provider_column_name = "embedding_model_provider"
        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        # Identifiers cannot be bound as parameters; both are constants above.
        select_sql = f"""
            SELECT id, {provider_column_name} AS provider_name, retrieval_model
            FROM {table_name}
            WHERE {provider_column_name} NOT LIKE '%/%'
              AND {provider_column_name} IS NOT NULL
              AND {provider_column_name} != ''
            LIMIT 1000
        """
        update_provider_sql = f"""
            UPDATE {table_name}
            SET {provider_column_name} = :provider_name
            WHERE id = :record_id
        """
        update_provider_and_retrieval_sql = f"""
            UPDATE {table_name}
            SET {provider_column_name} = :provider_name,
                retrieval_model = :retrieval_model
            WHERE id = :record_id
        """

        processed_count = 0
        # Failed rows still match the SELECT filter and come back on the next
        # pass; this set lets us skip them (O(1) membership) instead of
        # retrying forever.
        # NOTE(review): if 1000 or more rows fail they can fill a whole batch
        # and end the loop while unmigrated rows remain.
        failed_ids = set()

        while True:
            migrated_in_batch = 0
            with db.engine.begin() as conn:
                # Materialize the batch before issuing UPDATEs on the same connection.
                rows = conn.execute(db.text(select_sql)).fetchall()
                for row in rows:
                    record_id = str(row.id)
                    provider_name = str(row.provider_name)
                    if record_id in failed_ids:
                        continue

                    retrieval_model = cls._decode_retrieval_model(row.retrieval_model)
                    previous_reranking_provider = cls._upgrade_reranking_provider(
                        retrieval_model, DEFAULT_PLUGIN_ID
                    )
                    if previous_reranking_provider is not None:
                        click.echo(
                            click.style(
                                f"[{processed_count}] Migrating {table_name} {record_id} "
                                f"(reranking_provider_name: "
                                f"{previous_reranking_provider})",
                                fg="white",
                            )
                        )

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                            fg="white",
                        )
                    )

                    params = {
                        "record_id": record_id,
                        "provider_name": cls._qualify_provider_name(provider_name, DEFAULT_PLUGIN_ID),
                    }
                    update_sql = update_provider_sql
                    if previous_reranking_provider is not None:
                        params["retrieval_model"] = json.dumps(retrieval_model)
                        update_sql = update_provider_and_retrieval_sql

                    try:
                        # SAVEPOINT per row: on databases that abort the whole
                        # transaction after a failed statement (e.g. PostgreSQL),
                        # this keeps one bad row from poisoning the batch.
                        with conn.begin_nested():
                            conn.execute(db.text(update_sql), params)
                    except Exception:
                        failed_ids.add(record_id)
                        click.echo(
                            click.style(
                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
                                fg="red",
                            )
                        )
                        logger.exception(
                            "[%s] Failed to migrate [%s] %s (%s)",
                            processed_count,
                            table_name,
                            record_id,
                            provider_name,
                        )
                        continue

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
                            fg="green",
                        )
                    )
                    migrated_in_batch += 1
                    processed_count += 1

            if not migrated_in_batch:
                break

        if failed_ids:
            click.echo(
                click.style(
                    f"Failed to migrate {len(failed_ids)} [{table_name}] records, see logs for details",
                    fg="red",
                )
            )
        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )

    @classmethod
    def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
        """Qualify the bare provider names stored in one column of one table.

        Rows are read in ``id`` order, 5000 at a time, resuming after the last
        ``id`` seen (keyset pagination), and each batch is written back with a
        single executemany UPDATE in the same transaction. An error while
        writing a batch propagates to the caller.

        Args:
            table_name: Table to migrate. Interpolated into the SQL text, so it
                must be a trusted identifier, never user input.
            provider_column_name: Column holding the provider name. Also
                interpolated into the SQL text; same restriction applies.
        """
        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        # Built once: neither statement changes between batches.
        select_sql = f"""
            SELECT id, {provider_column_name} AS provider_name
            FROM {table_name}
            WHERE {provider_column_name} NOT LIKE '%/%'
              AND {provider_column_name} IS NOT NULL
              AND {provider_column_name} != ''
              AND id > :last_id
            ORDER BY id ASC
            LIMIT 5000
        """
        update_sql = f"""
            UPDATE {table_name}
            SET {provider_column_name} = :updated_value
            WHERE id = :record_id
        """

        processed_count = 0
        # Lowest possible UUID string, so the first batch starts at the beginning.
        last_id = "00000000-0000-0000-0000-000000000000"

        while True:
            with db.engine.begin() as conn:
                rows = conn.execute(db.text(select_sql), {"last_id": last_id}).fetchall()
                batch_updates = []
                for row in rows:
                    processed_count += 1
                    record_id = str(row.id)
                    last_id = record_id
                    provider_name = str(row.provider_name)
                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                            fg="white",
                        )
                    )
                    batch_updates.append(
                        {
                            "updated_value": cls._qualify_provider_name(provider_name, DEFAULT_PLUGIN_ID),
                            "record_id": record_id,
                        }
                    )

                if batch_updates:
                    conn.execute(db.text(update_sql), batch_updates)
                    click.echo(
                        click.style(
                            f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
                            fg="green",
                        )
                    )

            if not rows:
                break

        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )