diff --git a/api/commands.py b/api/commands.py index f2809be8e7..cd250a0b59 100644 --- a/api/commands.py +++ b/api/commands.py @@ -25,6 +25,7 @@ from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel from services.account_service import RegisterService, TenantService +from services.plugin.data_migration import PluginDataMigration @click.command("reset-password", help="Reset the account password.") @@ -639,6 +640,18 @@ where sites.id is null limit 1000""" click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) +@click.command("migrate-data-for-plugin", help="Migrate data for plugin.") +def migrate_data_for_plugin(): + """ + Migrate data for plugin. + """ + click.echo(click.style("Starting migrate data for plugin.", fg="white")) + + PluginDataMigration.migrate() + + click.echo(click.style("Migrate data for plugin completed.", fg="green")) + + def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) @@ -649,3 +662,4 @@ def register_commands(app): app.cli.add_command(create_tenant) app.cli.add_command(upgrade_db) app.cli.add_command(fix_app_site_missing) + app.cli.add_command(migrate_data_for_plugin) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index acc1a2d35b..3cd610464d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,4 +1,5 @@ from core.app.app_config.entities import ModelConfigEntity +from core.entities import DEFAULT_PLUGIN_ID from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager @@ -53,7 +54,15 @@ class ModelConfigManager: model_provider_factory = ModelProviderFactory(tenant_id) provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] - if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: + if "provider" not in config["model"]: + raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") + + if "/" not in config["model"]["provider"]: + config["model"]["provider"] = ( + f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" + ) + + if config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 764221dec5..534e00fdd9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -9,6 +9,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from constants import HIDDEN_VALUE +from core.entities import DEFAULT_PLUGIN_ID from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import ( CustomConfiguration, @@ -1047,6 +1048,9 @@ class ProviderConfigurations(BaseModel): return list(self.values()) def __getitem__(self, key): + if "/" not in key: + key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" + return self.configurations[key] def __setitem__(self, key, value): @@ -1059,6 +1063,9 @@ class ProviderConfigurations(BaseModel): return iter(self.configurations.values()) def get(self, key, default=None): + if "/" not in key: + key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" + return self.configurations.get(key, default) diff --git a/api/services/plugin/__init__.py b/api/services/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py new file mode 100644 index 0000000000..27aa308e55 --- /dev/null +++ b/api/services/plugin/data_migration.py @@ -0,0 +1,184 @@ +import json +import logging + +import click + +from core.entities import DEFAULT_PLUGIN_ID +from extensions.ext_database import db + +logger = logging.getLogger(__name__) + + +class PluginDataMigration: + @classmethod + def migrate(cls) -> None: + cls.migrate_db_records("providers", "provider_name") # large table + cls.migrate_db_records("provider_models", "provider_name") + cls.migrate_db_records("provider_orders", "provider_name") + cls.migrate_db_records("tenant_default_models", "provider_name") + cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") + cls.migrate_db_records("provider_model_settings", "provider_name") + cls.migrate_db_records("load_balancing_model_configs", "provider_name") + cls.migrate_datasets() + cls.migrate_db_records("embeddings", "provider_name") # large table + cls.migrate_db_records("dataset_collection_bindings", "provider_name") + + @classmethod + def migrate_datasets(cls) -> None: + table_name = "datasets" + provider_column_name = "embedding_model_provider" + + click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) + + processed_count = 0 + failed_ids = [] + while True: + sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name} +where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != '' +limit 1000""" + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql)) + + current_iter_count = 0 + for i in rs: + record_id = str(i.id) + provider_name = str(i.provider_name) + retrieval_model = i.retrieval_model + print(type(retrieval_model)) + + if record_id in failed_ids: + continue + + retrieval_model_changed = False + if retrieval_model: + if ( + "reranking_model" in retrieval_model + and "reranking_provider_name" in retrieval_model["reranking_model"] + and retrieval_model["reranking_model"]["reranking_provider_name"] + and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"] + ): + click.echo( + click.style( + f"[{processed_count}] Migrating {table_name} {record_id} " + f"(reranking_provider_name: " + f"{retrieval_model['reranking_model']['reranking_provider_name']})", + fg="white", + ) + ) + retrieval_model["reranking_model"]["reranking_provider_name"] = ( + f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" + ) + retrieval_model_changed = True + + click.echo( + click.style( + f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})", + fg="white", + ) + ) + + try: + # update provider name append with "langgenius/{provider_name}/{provider_name}" + params = {"record_id": record_id} + update_retrieval_model_sql = "" + if retrieval_model and retrieval_model_changed: + update_retrieval_model_sql = ", retrieval_model = :retrieval_model" + params["retrieval_model"] = json.dumps(retrieval_model) + + sql = f"""update {table_name} + set {provider_column_name} = + concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) + {update_retrieval_model_sql} + where id = :record_id""" + conn.execute(db.text(sql), params) + click.echo( + click.style( + f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})", + fg="green", + ) + ) + except Exception: + failed_ids.append(record_id) + click.echo( + click.style( + f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})", + fg="red", + ) + ) + logger.exception( + f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" + ) + continue + + current_iter_count += 1 + processed_count += 1 + + if not current_iter_count: + break + + click.echo( + click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green") + ) + + @classmethod + def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: + click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) + + processed_count = 0 + failed_ids = [] + while True: + sql = f"""select id, {provider_column_name} as provider_name from {table_name} +where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != '' +limit 1000""" + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql)) + + current_iter_count = 0 + for i in rs: + current_iter_count += 1 + processed_count += 1 + record_id = str(i.id) + provider_name = str(i.provider_name) + + if record_id in failed_ids: + continue + + click.echo( + click.style( + f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})", + fg="white", + ) + ) + + try: + # update provider name append with "langgenius/{provider_name}/{provider_name}" + sql = f"""update {table_name} + set {provider_column_name} = + concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) + where id = :record_id""" + conn.execute(db.text(sql), {"record_id": record_id}) + click.echo( + click.style( + f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})", + fg="green", + ) + ) + except Exception: + failed_ids.append(record_id) + click.echo( + click.style( + f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})", + fg="red", + ) + ) + logger.exception( + f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" + ) + continue + + if not current_iter_count: + break + + click.echo( + click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green") + )