Feat/vdb migrate command (#2562)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong 2024-02-26 19:47:29 +08:00 committed by GitHub
parent d93288f711
commit 0620fa3094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 134 additions and 57 deletions

View File

@ -1,20 +1,21 @@
import base64 import base64
import json import json
import secrets import secrets
from typing import cast
import click import click
from flask import current_app from flask import current_app
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding from core.rag.datasource.vdb.vector_factory import Vector
from core.model_manager import ModelManager from core.rag.models.document import Document
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email as email_validate from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair from libs.rsa import generate_key_pair
from models.account import Tenant from models.account import Tenant
from models.dataset import Dataset from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account from models.model import Account
from models.provider import Provider, ProviderModel from models.provider import Provider, ProviderModel
@ -124,14 +125,15 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.') @click.command('vdb-migrate', help='migrate vector db.')
def create_qdrant_indexes(): def vdb_migrate():
""" """
Migrate other vector database datas to Qdrant. Migrate vector database datas to target vector database .
""" """
click.echo(click.style('Start create qdrant indexes.', fg='green')) click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0 create_count = 0
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
page = 1 page = 1
while True: while True:
try: try:
@ -140,50 +142,97 @@ def create_qdrant_indexes():
except NotFound: except NotFound:
break break
model_manager = ModelManager()
page += 1 page += 1
for dataset in datasets: for dataset in datasets:
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict: if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant': if dataset.index_struct_dict['type'] == vector_type:
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
continue continue
embeddings = CacheEmbedding(embedding_model) if vector_type == "weaviate":
dataset_id = dataset.id
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
index = QdrantVectorIndex( "type": 'weaviate',
dataset=dataset, "vector_store": {"class_prefix": collection_name}
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
} }
dataset.index_struct = json.dumps(index_struct) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant":
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
try:
vector.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
vector.create(documents)
except Exception as e:
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit() db.session.commit()
create_count += 1 create_count += 1
else:
click.echo('passed.')
except Exception as e: except Exception as e:
db.session.rollback()
click.echo( click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red')) fg='red'))
@ -196,4 +245,4 @@ def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(create_qdrant_indexes) app.cli.add_command(vdb_migrate)

View File

@ -664,6 +664,7 @@ class IndexingRunner:
) )
# load index # load index
index_processor.load(dataset, chunk_documents) index_processor.load(dataset, chunk_documents)
db.session.add(dataset)
document_ids = [document.metadata['doc_id'] for document in chunk_documents] document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).filter(

View File

@ -127,9 +127,15 @@ class MilvusVector(BaseVector):
self._client.delete(collection_name=self._collection_name, pks=doc_ids) self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None: def delete(self) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility from pymilvus import utility
utility.drop_collection(self._collection_name, None) utility.drop_collection(self._collection_name, None, using=alias)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:

View File

@ -1,3 +1,4 @@
import json
from typing import Any, cast from typing import Any, cast
from flask import current_app from flask import current_app
@ -39,6 +40,11 @@ class Vector:
else: else:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return WeaviateVector( return WeaviateVector(
collection_name=collection_name, collection_name=collection_name,
config=WeaviateConfig( config=WeaviateConfig(
@ -66,6 +72,13 @@ class Vector:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
if not self._dataset.index_struct_dict:
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return QdrantVector( return QdrantVector(
collection_name=collection_name, collection_name=collection_name,
group_id=self._dataset.id, group_id=self._dataset.id,
@ -84,6 +97,11 @@ class Vector:
else: else:
dataset_id = self._dataset.id dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return MilvusVector( return MilvusVector(
collection_name=collection_name, collection_name=collection_name,
config=MilvusConfig( config=MilvusConfig(

View File

@ -127,6 +127,9 @@ class WeaviateVector(BaseVector):
) )
def delete(self): def delete(self):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
self._client.schema.delete_class(self._collection_name) self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool: