Fix/qdrant data issue (#1203)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 2023-09-22 14:21:26 +08:00 committed by GitHub
parent e409895c02
commit 724e053732
7 changed files with 171 additions and 149 deletions

View File: api/commands.py

@@ -3,12 +3,13 @@ import json
 import math
 import random
 import string
+import threading
 import time
 import uuid
 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
@@ -456,92 +457,92 @@ def update_qdrant_indexes():
 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
-                try:
-                    click.echo('restore dataset index: {}'.format(dataset.id))
-                    try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id,
-                            model_provider_name=dataset.embedding_model_provider,
-                            model_name=dataset.embedding_model
-                        )
-                    except Exception:
-                        provider = Provider(
-                            id='provider_id',
-                            tenant_id=dataset.tenant_id,
-                            provider_name='openai',
-                            provider_type=ProviderType.CUSTOM.value,
-                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                            is_valid=True,
-                        )
-                        model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                          model_provider=model_provider)
-                    embeddings = CacheEmbedding(embedding_model)
-                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
-                               DatasetCollectionBinding.model_name == embedding_model.name). \
-                        order_by(DatasetCollectionBinding.created_at). \
-                        first()
-                    if not dataset_collection_binding:
-                        dataset_collection_binding = DatasetCollectionBinding(
-                            provider_name=embedding_model.model_provider.provider_name,
-                            model_name=embedding_model.name,
-                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
-                        )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
-                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
-                    index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if index:
-                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
-                    else:
-                        click.echo('passed.')
-                    original_index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if original_index:
-                        original_index.delete_original_collection(dataset, dataset_collection_binding)
-                        normalization_count += 1
-                    else:
-                        click.echo('passed.')
-                except Exception as e:
-                    click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
-                    continue
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
+        try:
+            click.echo('restore dataset index: {}'.format(dataset.id))
+            try:
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except Exception:
+                provider = Provider(
+                    id='provider_id',
+                    tenant_id=dataset.tenant_id,
+                    provider_name='openai',
+                    provider_type=ProviderType.CUSTOM.value,
+                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                    is_valid=True,
+                )
+                model_provider = OpenAIProvider(provider=provider)
+                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                  model_provider=model_provider)
+            embeddings = CacheEmbedding(embedding_model)
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                       DatasetCollectionBinding.model_name == embedding_model.name). \
+                order_by(DatasetCollectionBinding.created_at). \
+                first()
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=embedding_model.model_provider.provider_name,
+                    model_name=embedding_model.name,
+                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+            index = QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=current_app.config.get('QDRANT_URL'),
+                    api_key=current_app.config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+            if index:
+                # index.delete_by_group_id(dataset.id)
+                index.restore_dataset_in_one(dataset, dataset_collection_binding)
+            else:
+                click.echo('passed.')
+            normalization_count.append(1)
+        except Exception as e:
+            click.echo(
+                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                            fg='red'))
 
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
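The rewritten command fans the per-dataset work out to threads in batches of five and joins each batch before starting the next, so at most five restores run concurrently; the shared normalization_count list doubles as a completion counter, which is safe because list.append is atomic under CPython's GIL. Below is a standalone sketch of the same pattern, assuming only Flask and the standard library; all names in it are illustrative, not from the commit.

import threading
from flask import Flask, current_app

app = Flask(__name__)

def restore_one(flask_app: Flask, dataset_id: str, done: list):
    # Worker threads do not inherit the app context, so each one must
    # push its own context before touching current_app or the db.
    with flask_app.app_context():
        done.append(dataset_id)

def restore_in_batches(dataset_ids: list, batch_size: int = 5) -> int:
    done: list = []
    for i in range(0, len(dataset_ids), batch_size):
        threads = []
        for dataset_id in dataset_ids[i:i + batch_size]:
            t = threading.Thread(target=restore_one, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'done': done,
            })
            threads.append(t)
            t.start()
        for t in threads:
            t.join()  # finish the whole batch before starting the next
    return len(done)

with app.app_context():
    print(restore_in_batches([str(n) for n in range(12)]))  # prints 12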

View File: api/core/index/vector_index/base.py

@@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex):
     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()
 
     def delete(self) -> None:
         vector_store = self._get_vector_store()
@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
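The new branch in delete_by_group_id carries the core idea of this fix: a dataset bound to a shared, normalized collection may only remove its own points, while a legacy dataset that still owns a dedicated collection can drop it wholesale. A minimal sketch of that dispatch, with hypothetical class and function names:

from typing import Optional

class VectorStoreSketch:
    # Hypothetical stand-in for the wrapped vector store.
    def delete(self) -> None:
        print("drop the entire collection")

    def delete_by_group_id(self, group_id: str) -> None:
        print(f"delete only points whose payload group_id == {group_id}")

def delete_dataset_vectors(store: VectorStoreSketch,
                           collection_binding_id: Optional[str],
                           dataset_id: str) -> None:
    if collection_binding_id:
        # Shared collection: other datasets' vectors live here too.
        store.delete_by_group_id(dataset_id)
    else:
        # Dedicated legacy collection: safe to remove outright.
        store.delete()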

View File: api/core/vector_store/vector/qdrant.py

@@ -1390,70 +1390,12 @@ class Qdrant(VectorStore):
             path=path,
             **kwargs,
         )
-        try:
-            # Skip any validation in case of forced collection recreate.
-            if force_recreate:
-                raise ValueError
-
-            # Get the vector configuration of the existing collection and vector, if it
-            # was specified. If the old configuration does not match the current one,
-            # an exception is being thrown.
-            collection_info = client.get_collection(collection_name=collection_name)
-            current_vector_config = collection_info.config.params.vectors
-            if isinstance(current_vector_config, dict) and vector_name is not None:
-                if vector_name not in current_vector_config:
-                    raise QdrantException(
-                        f"Existing Qdrant collection {collection_name} does not "
-                        f"contain vector named {vector_name}. Did you mean one of the "
-                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                        f"If you want to recreate the collection, set `force_recreate` "
-                        f"parameter to `True`."
-                    )
-                current_vector_config = current_vector_config.get(
-                    vector_name
-                )  # type: ignore[assignment]
-            elif isinstance(current_vector_config, dict) and vector_name is None:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} uses named vectors. "
-                    f"If you want to reuse it, please set `vector_name` to any of the "
-                    f"existing named vectors: "
-                    f"{', '.join(current_vector_config.keys())}."  # noqa
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-            ):
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} doesn't use named "
-                    f"vectors. If you want to reuse it, please set `vector_name` to "
-                    f"`None`. If you want to recreate the collection, set "
-                    f"`force_recreate` parameter to `True`."
-                )
-            # Check if the vector configuration has the same dimensionality.
-            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for vectors with "
-                    f"{current_vector_config.size} "  # type: ignore[union-attr]
-                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            current_distance_func = (
-                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-            )
-            if current_distance_func != distance_func:
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for "
-                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                    f"similarity. Please set `distance_func` parameter to "
-                    f"`{distance_func}` if you want to reuse it. If you want to "
-                    f"recreate the collection, set `force_recreate` parameter to "
-                    f"`True`."
-                )
-        except (UnexpectedResponse, RpcError, ValueError):
+        all_collection_name = []
+        collections_response = client.get_collections()
+        collection_list = collections_response.collections
+        for collection in collection_list:
+            all_collection_name.append(collection.name)
+        if collection_name not in all_collection_name:
             vectors_config = rest.VectorParams(
                 size=vector_size,
                 distance=rest.Distance[distance_func],
@@ -1481,6 +1423,67 @@ class Qdrant(VectorStore):
                 timeout=timeout,  # type: ignore[arg-type]
             )
             is_new_collection = True
+        if force_recreate:
+            raise ValueError
+
+        # Get the vector configuration of the existing collection and vector, if it
+        # was specified. If the old configuration does not match the current one,
+        # an exception is being thrown.
+        collection_info = client.get_collection(collection_name=collection_name)
+        current_vector_config = collection_info.config.params.vectors
+        if isinstance(current_vector_config, dict) and vector_name is not None:
+            if vector_name not in current_vector_config:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} does not "
+                    f"contain vector named {vector_name}. Did you mean one of the "
+                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            current_vector_config = current_vector_config.get(
+                vector_name
+            )  # type: ignore[assignment]
+        elif isinstance(current_vector_config, dict) and vector_name is None:
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} uses named vectors. "
+                f"If you want to reuse it, please set `vector_name` to any of the "
+                f"existing named vectors: "
+                f"{', '.join(current_vector_config.keys())}."  # noqa
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        elif (
+            not isinstance(current_vector_config, dict) and vector_name is not None
+        ):
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} doesn't use named "
+                f"vectors. If you want to reuse it, please set `vector_name` to "
+                f"`None`. If you want to recreate the collection, set "
+                f"`force_recreate` parameter to `True`."
+            )
+        # Check if the vector configuration has the same dimensionality.
+        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for vectors with "
+                f"{current_vector_config.size} "  # type: ignore[union-attr]
+                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        current_distance_func = (
+            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+        )
+        if current_distance_func != distance_func:
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for "
+                f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                f"similarity. Please set `distance_func` parameter to "
+                f"`{distance_func}` if you want to reuse it. If you want to "
+                f"recreate the collection, set `force_recreate` parameter to "
+                f"`True`."
+            )
         qdrant = cls(
             client=client,
             collection_name=collection_name,
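The rewritten construct_instance no longer probes get_collection() and treats an exception as "missing"; it lists the existing collections, creates the one it needs only when absent, and runs the configuration checks afterwards. A sketch of that list-then-create pattern against a throwaway in-memory Qdrant, using only public qdrant-client calls; the collection name and vector size are made up for the example:

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(":memory:")  # throwaway instance for the sketch
collection_name = "Vector_index_example_Node"

# List what exists instead of probing get_collection() and catching
# UnexpectedResponse, as the hunk above does.
existing = [c.name for c in client.get_collections().collections]
if collection_name not in existing:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(size=1536, distance=rest.Distance.COSINE),
    )

# Configuration validation (dimensionality, distance) still runs afterwards.
info = client.get_collection(collection_name=collection_name)
assert info.config.params.vectors.size == 1536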

View File: api/core/index/vector_index/qdrant_vector_index.py

@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))
 
+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
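The del_texts call above scopes deletion with a payload filter on group_id, which is what lets several datasets coexist in one collection. Stripped of Dify's wrappers, the equivalent raw qdrant-client call would look roughly like this; the collection name and dataset id are placeholders:

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(":memory:")
client.recreate_collection(
    collection_name="Vector_index_example_Node",
    vectors_config=models.VectorParams(size=4, distance=models.Distance.COSINE),
)

# Remove only the points tagged with this dataset's group_id, leaving
# every other dataset in the shared collection untouched.
client.delete(
    collection_name="Vector_index_example_Node",
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value="dataset-id-placeholder"),
                ),
            ],
        ),
    ),
)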

View File: api/events/event_handlers/clean_when_dataset_deleted.py

@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)

View File: api/tasks/clean_dataset_task.py

@@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id
 
     Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
     """
@@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             id=dataset_id,
             tenant_id=tenant_id,
             indexing_technique=indexing_technique,
-            index_struct=index_struct
+            index_struct=index_struct,
+            collection_binding_id=collection_binding_id
         )
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
@@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         if dataset.indexing_technique == 'high_quality':
             vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
-                vector_index.delete()
+                vector_index.delete_by_group_id(dataset.id)
             except Exception:
                 logging.exception("Delete doc index failed when dataset deleted.")

View File: api/tasks/deal_dataset_vector_index_task.py

@@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')
 
         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-            index.delete()
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index.delete_by_group_id(dataset.id)
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
                 DatasetDocument.dataset_id == dataset_id,
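Two details in this last hunk carry the fix. Deletion again goes through delete_by_group_id, so only this dataset's points leave the shared collection. And ignore_high_quality_check flips to True because, on the remove path, the dataset row may already have been switched away from 'high_quality'; with the old check the index lookup would come back empty and the vectors would be orphaned. A toy illustration of that second point, where the helper is a hypothetical stand-in for IndexBuilder.get_index:

from typing import Optional

def get_index_sketch(indexing_technique: str,
                     ignore_high_quality_check: bool) -> Optional[str]:
    # With the check enforced, a dataset already flipped to 'economy'
    # yields no index, so its vectors could never be cleaned up.
    if not ignore_high_quality_check and indexing_technique != 'high_quality':
        return None
    return "qdrant-index"

assert get_index_sketch('economy', ignore_high_quality_check=False) is None
assert get_index_sketch('economy', ignore_high_quality_check=True) == "qdrant-index"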