Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 2023-09-18 18:15:41 +08:00 committed by GitHub
parent 60e0bbd713
commit 269a465fc4
14 changed files with 463 additions and 46 deletions
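The core idea of this change: datasets that share an embedding model are bound to a single Qdrant collection through the new DatasetCollectionBinding table, and every point stored in that shared collection carries a group_id payload equal to the dataset id, so per-dataset reads and deletes become payload filters instead of separate collections. A minimal sketch of that pattern with qdrant_client, assuming a local Qdrant instance and illustrative names (not code from this commit):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # assumed local instance

# keyword index on the payload field keeps per-dataset filtering fast
client.create_payload_index(
    collection_name="Vector_index_shared_Node",      # illustrative shared collection
    field_name="group_id",
    field_schema=models.PayloadSchemaType.KEYWORD,
)

# every query against the shared collection is scoped to one dataset
hits = client.search(
    collection_name="Vector_index_shared_Node",
    query_vector=[0.0] * 1536,                        # placeholder embedding
    query_filter=models.Filter(
        must=[models.FieldCondition(key="group_id",
                                    match=models.MatchValue(value="<dataset-id>"))]
    ),
    limit=4,
)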

View File

@ -4,6 +4,7 @@ import math
import random
import string
import time
import uuid
import click
from tqdm import tqdm
@ -23,7 +24,7 @@ from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
@ -346,7 +353,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@ -364,7 +372,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@ -373,7 +382,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@ -414,7 +424,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@ -435,11 +446,104 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='Restore all dataset vector collections into shared collections (one per embedding model).')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
if not dataset.collection_binding_id:
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
original_index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if original_index:
original_index.delete_original_collection(dataset, dataset_collection_binding)
normalization_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Restore dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Restored {} dataset indexes.'.format(normalization_count), fg='green'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()
account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@ -551,4 +657,5 @@ def register_commands(app):
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
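Since normalization_collections is only exposed through the Flask CLI, one way to exercise it is Flask's test CLI runner; a usage sketch assuming the project's create_app application factory (names are assumptions, not part of this commit):

from app import create_app  # assumed application factory

app = create_app()
runner = app.test_cli_runner()

# equivalent to running `flask normalization-collections` against this app
result = runner.invoke(args=['normalization-collections'])
print(result.output)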

View File

@ -16,6 +16,10 @@ class BaseIndex(ABC):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@ -28,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError

View File

@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',

View File

@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument
@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@ -84,6 +85,7 @@ class Qdrant(VectorStore):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
@ -93,9 +95,12 @@ class Qdrant(VectorStore):
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
@ -129,7 +134,10 @@ class Qdrant(VectorStore):
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
self.is_new_collection = is_new_collection
if embedding_function is not None:
warnings.warn(
@ -170,6 +178,8 @@ class Qdrant(VectorStore):
batch_size:
How many vectors upload per-request.
Default: 64
group_id:
Group id written into each point's payload.
Returns:
List of ids from adding the texts into the vectorstore.
@ -182,7 +192,11 @@ class Qdrant(VectorStore):
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
# if this is a new collection, create a payload index on group_id
if self.is_new_collection:
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
return added_ids
@sync_call_fallback
@ -970,6 +984,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the group id of the document.
Default: "group_id"
group_id:
Group id written into each document's payload.
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
client=client,
collection_name=collection_name,
@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
else:
out.append(
rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}",
key=key,
match=rest.MatchValue(value=value),
)
)
@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]
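With group_payload_key, the group id sits at the payload root rather than inside metadata, which is also why the filter builder above now matches on the bare key instead of prefixing it with metadata_payload_key. A sketch of the payload layout each point ends up with (illustrative values):

# payload written for each point, per _build_payloads above
payload = {
    "page_content": "segment text ...",                       # content_payload_key
    "metadata": {"doc_id": "...", "document_id": "...",
                 "dataset_id": "...", "doc_hash": "..."},      # metadata_payload_key
    "group_id": "<dataset-id>",                                # group_payload_key, at the root
}

# a retriever filter of {'group_id': ['<dataset-id>']} therefore has to become
# FieldCondition(key="group_id", ...), not key="metadata.group_id"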

View File

@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset collection binding does not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content'
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
))
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
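The HnswConfigDiff(m=0, payload_m=16, ...) used in both create paths matches Qdrant's documented multitenancy setup: m=0 skips the global HNSW graph and payload_m builds graph links per group once a keyword payload index exists on the tenant field. A standalone sketch of that collection setup (illustrative names, assuming 1536-dimensional embeddings):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")
client.create_collection(
    collection_name="Vector_index_shared_Node",
    vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
    # m=0: no global graph; payload_m=16: per-group links for filtered search
    hnsw_config=models.HnswConfigDiff(m=0, payload_m=16),
)
client.create_payload_index(
    collection_name="Vector_index_shared_Node",
    field_name="group_id",
    field_schema=models.PayloadSchemaType.KEYWORD,
)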

View File

@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k
'k': self.k,
'filter': {
'group_id': [dataset.id]
}
}
)
else:

View File

@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,

View File

@ -0,0 +1,47 @@
"""add_dataset_collection_binding
Revision ID: 6e2cfb077b04
Revises: 77e83833755c
Create Date: 2023-09-13 22:16:48.027810
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6e2cfb077b04'
down_revision = '77e83833755c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_collection_bindings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('provider_name', sa.String(length=40), nullable=False),
sa.Column('model_name', sa.String(length=40), nullable=False),
sa.Column('collection_name', sa.String(length=64), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
)
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('collection_binding_id')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_index('provider_model_name_idx')
op.drop_table('dataset_collection_bindings')
# ### end Alembic commands ###

View File

@ -38,6 +38,8 @@ class Dataset(db.Model):
server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
@property
def dataset_keyword_table(self):
@ -445,3 +447,19 @@ class Embedding(db.Model):
def get_embedding(self) -> list[float]:
return pickle.loads(self.embedding)
class DatasetCollectionBinding(db.Model):
__tablename__ = 'dataset_collection_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
db.Index('provider_model_name_idx', 'provider_name', 'model_name')
)
id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
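A dataset's physical collection is then resolved through this table; a short sketch mirroring QdrantVectorIndex.get_index_name above (not new code in this commit):

binding = db.session.query(DatasetCollectionBinding). \
    filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
    one_or_none()
collection_name = binding.collection_name if binding \
    else "Vector_index_" + dataset.id.replace("-", "_") + '_Node'  # legacy per-dataset name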

View File

@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
@ -147,6 +148,7 @@ class DatasetService:
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
filtered_data['collection_binding_id'] = None
elif data['indexing_technique'] == 'high_quality':
action = 'add'
# get embedding model setting
@ -156,6 +158,11 @@ class DatasetService:
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
f"No Embedding Model available. Please configure a valid provider "
@ -464,7 +471,11 @@ class DocumentService:
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@ -720,10 +731,16 @@ class DocumentService:
if total_count > tenant_document_count:
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
if document_data['indexing_technique'] == 'high_quality':
embedding_model = ModelFactory.get_embedding_model(
tenant_id=tenant_id
)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.model_provider.provider_name,
embedding_model.name
)
dataset_collection_binding_id = dataset_collection_binding.id
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@ -732,7 +749,8 @@ class DocumentService:
indexing_technique=document_data["indexing_technique"],
created_by=account.id,
embedding_model=embedding_model.name if embedding_model else None,
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
collection_binding_id=dataset_collection_binding_id
)
db.session.add(dataset)
@ -1069,3 +1087,23 @@ class SegmentService:
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()
class DatasetCollectionBindingService:
@classmethod
def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.flush()
return dataset_collection_binding
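Callers such as DatasetService and DocumentService above then need only one line to attach a dataset to its shared collection; an illustrative usage (provider and model names are examples):

binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    'openai', 'text-embedding-ada-002'
)
dataset.collection_binding_id = binding.id
db.session.commit()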

View File

@ -47,7 +47,10 @@ class HitTestingService:
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 10
'k': 10,
'filter': {
'group_id': [dataset.id]
}
}
)
end = time.perf_counter()