diff --git a/api/commands.py b/api/commands.py index cbd9fb3711..dba59da628 100644 --- a/api/commands.py +++ b/api/commands.py @@ -339,26 +339,7 @@ def create_qdrant_indexes(): ) except Exception: - try: - embedding_model = model_manager.get_default_model_instance( - tenant_id=dataset.tenant_id, - model_type=ModelType.TEXT_EMBEDDING, - ) - dataset.embedding_model = embedding_model.model - dataset.embedding_model_provider = embedding_model.provider - except Exception: - - provider = Provider( - id='provider_id', - tenant_id=dataset.tenant_id, - provider_name='openai', - provider_type=ProviderType.SYSTEM.value, - encrypted_config=json.dumps({'openai_api_key': 'TEST'}), - is_valid=True, - ) - model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", - model_provider=model_provider) + continue embeddings = CacheEmbedding(embedding_model) from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex @@ -405,7 +386,7 @@ def update_qdrant_indexes(): .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) except NotFound: break - + model_manager = ModelManager() page += 1 for dataset in datasets: if dataset.index_struct_dict: @@ -413,23 +394,15 @@ def update_qdrant_indexes(): try: click.echo('Update dataset qdrant index: {}'.format(dataset.id)) try: - embedding_model = ModelFactory.get_embedding_model( + embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, - model_provider_name=dataset.embedding_model_provider, - model_name=dataset.embedding_model + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) except Exception: - provider = Provider( - id='provider_id', - tenant_id=dataset.tenant_id, - provider_name='openai', - provider_type=ProviderType.CUSTOM.value, - encrypted_config=json.dumps({'openai_api_key': 'TEST'}), - is_valid=True, - ) - model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", - model_provider=model_provider) + continue embeddings = CacheEmbedding(embedding_model) from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex @@ -524,23 +497,17 @@ def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: try: click.echo('restore dataset index: {}'.format(dataset.id)) try: - embedding_model = ModelFactory.get_embedding_model( + model_manager = ModelManager() + + embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, - model_provider_name=dataset.embedding_model_provider, - model_name=dataset.embedding_model + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model + ) except Exception: - provider = Provider( - id='provider_id', - tenant_id=dataset.tenant_id, - provider_name='openai', - provider_type=ProviderType.CUSTOM.value, - encrypted_config=json.dumps({'openai_api_key': 'TEST'}), - is_valid=True, - ) - model_provider = OpenAIProvider(provider=provider) - embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", - model_provider=model_provider) + pass embeddings = CacheEmbedding(embedding_model) dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,