From 84d118de074ced0ea9c21fed0d62bc315c3e5553 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 1 Apr 2024 02:10:41 +0800 Subject: [PATCH] add redis lock on create collection in multiple thread mode (#3054) Co-authored-by: jyong --- .../rag/datasource/keyword/jieba/jieba.py | 41 ++++---- .../datasource/vdb/milvus/milvus_vector.py | 97 ++++++++++--------- .../datasource/vdb/qdrant/qdrant_vector.py | 73 +++++++------- .../vdb/weaviate/weaviate_vector.py | 22 +++-- 4 files changed, 128 insertions(+), 105 deletions(-) diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 94a692637f..344ef7babe 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.models.document import Document from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment @@ -121,26 +122,28 @@ class Jieba(BaseKeyword): db.session.commit() def _get_dataset_keyword_table(self) -> Optional[dict]: - dataset_keyword_table = self.dataset.dataset_keyword_table - if dataset_keyword_table: - if dataset_keyword_table.keyword_table_dict: - return dataset_keyword_table.keyword_table_dict['__data__']['table'] - else: - dataset_keyword_table = DatasetKeywordTable( - dataset_id=self.dataset.id, - keyword_table=json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) - ) - db.session.add(dataset_keyword_table) - db.session.commit() + lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + with redis_client.lock(lock_name, timeout=20): + dataset_keyword_table = self.dataset.dataset_keyword_table + if dataset_keyword_table: + if dataset_keyword_table.keyword_table_dict: + return dataset_keyword_table.keyword_table_dict['__data__']['table'] + else: + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self.dataset.id, + keyword_table=json.dumps({ + '__type__': 'keyword_table', + '__data__': { + "index_id": self.dataset.id, + "summary": None, + "table": {} + } + }, cls=SetEncoder) + ) + db.session.add(dataset_keyword_table) + db.session.commit() - return {} + return {} def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: for keyword in keywords: diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index f62d603d8d..dcb37ccbe6 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document +from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -61,17 +62,7 @@ class MilvusVector(BaseVector): 'params': {"M": 8, "efConstruction": 64} } metadatas = [d.metadata for d in texts] - - # Grab the existing collection if it exists - from pymilvus import utility - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) - if not utility.has_collection(self._collection_name, using=alias): - self.create_collection(embeddings, metadatas, index_params) + self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -187,46 +178,60 @@ class MilvusVector(BaseVector): def create_collection( self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None - ) -> str: - from pymilvus import CollectionSchema, DataType, FieldSchema - from pymilvus.orm.types import infer_dtype_bydata + ): + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + # Grab the existing collection if it exists + from pymilvus import utility + alias = uuid4().hex + if self._client_config.secure: + uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) + else: + uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) + connections.connect(alias=alias, uri=uri, user=self._client_config.user, + password=self._client_config.password) + if not utility.has_collection(self._collection_name, using=alias): + from pymilvus import CollectionSchema, DataType, FieldSchema + from pymilvus.orm.types import infer_dtype_bydata - # Determine embedding dim - dim = len(embeddings[0]) - fields = [] - if metadatas: - fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + if metadatas: + fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) - # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) - # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) - # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + # Create the text field + fields.append( + FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) + ) - # Create the schema for the collection - schema = CollectionSchema(fields) + # Create the schema for the collection + schema = CollectionSchema(fields) - for x in schema.fields: - self._fields.append(x.name) - # Since primary field is auto-id, no need to track it - self._fields.remove(Field.PRIMARY_KEY.value) - - # Create the collection - collection_name = self._collection_name - self._client.create_collection_with_schema(collection_name=collection_name, - schema=schema, index_param=index_params, - consistency_level=self._consistency_level) - return collection_name + for x in schema.fields: + self._fields.append(x.name) + # Since primary field is auto-id, no need to track it + self._fields.remove(Field.PRIMARY_KEY.value) + # Create the collection + collection_name = self._collection_name + self._client.create_collection_with_schema(collection_name=collection_name, + schema=schema, index_param=index_params, + consistency_level=self._consistency_level) + redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: if config.secure: uri = "https://" + str(config.host) + ":" + str(config.port) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 436e6b5f6a..41e8c6154a 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document +from extensions.ext_redis import redis_client if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -77,6 +78,17 @@ class QdrantVector(BaseVector): vector_size = len(embeddings[0]) # get collection name collection_name = self._collection_name + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + lock_name = 'vector_indexing_lock_{}'.format(collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return collection_name = collection_name or uuid.uuid4().hex all_collection_name = [] collections_response = self._client.get_collections() @@ -84,40 +96,35 @@ class QdrantVector(BaseVector): for collection in collection_list: all_collection_name.append(collection.name) if collection_name not in all_collection_name: - # create collection - self.create_collection(collection_name, vector_size) + from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, + max_indexing_threads=0, on_disk=False) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) - self.add_texts(texts, embeddings, **kwargs) - - def create_collection(self, collection_name: str, vector_size: int): - from qdrant_client.http import models as rest - vectors_config = rest.VectorParams( - size=vector_size, - distance=rest.Distance[self._distance_func], - ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) - self._client.recreate_collection( - collection_name=collection_name, - vectors_config=vectors_config, - hnsw_config=hnsw_config, - timeout=int(self._client_config.timeout), - ) - - # create payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD, - field_type=PayloadSchemaType.KEYWORD) - # creat full text index - text_index_params = TextIndexParams( - type=TextIndexType.TEXT, - tokenizer=TokenizerType.MULTILINGUAL, - min_token_len=2, - max_token_len=20, - lowercase=True - ) - self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) + # create payload index + self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, + field_schema=PayloadSchemaType.KEYWORD, + field_type=PayloadSchemaType.KEYWORD) + # creat full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True + ) + self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, + field_schema=text_index_params) + redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 5d24ee9fd2..59fbaeee6a 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document +from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -79,16 +80,23 @@ class WeaviateVector(BaseVector): } def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - - schema = self._default_schema(self._collection_name) - - # check whether the index already exists - if not self._client.schema.contains(schema): - # create collection - self._client.schema.create_class(schema) + # create collection + self._create_collection() # create vector self.add_texts(texts, embeddings) + def _create_collection(self): + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + schema = self._default_schema(self._collection_name) + if not self._client.schema.contains(schema): + # create collection + self._client.schema.create_class(schema) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents]