Lindorm vdb (#11574)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
Authored by Jiang on 2024-12-12 09:43:27 +08:00; committed by GitHub
parent 926f604f09
commit 0d04cdc323
4 changed files with 157 additions and 123 deletions


@@ -294,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
+USING_UGC_INDEX=False
 
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1


@@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
+    DEFAULT_INDEX_TYPE: Optional[str] = Field(
+        description="Lindorm Vector Index Type, hnsw or flat is available in dify",
+        default="hnsw",
+    )
+    DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
+        description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
+    )
+    USING_UGC_INDEX: Optional[bool] = Field(
+        description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
+        default=False,
+    )
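
Taken together, the new settings default to an HNSW index with L2 distance and keep the old per-dataset index layout unless USING_UGC_INDEX is switched on. A standalone sketch of the same defaults, using plain pydantic rather than dify's settings wiring (class name is hypothetical):

    from typing import Optional

    from pydantic import BaseModel, Field

    class LindormSettingsSketch(BaseModel):
        # Mirrors the three fields added above, with their defaults.
        DEFAULT_INDEX_TYPE: Optional[str] = Field(default="hnsw")  # "hnsw" or "flat"
        DEFAULT_DISTANCE_TYPE: Optional[str] = Field(default="l2")  # "l2", "cosinesimil", "innerproduct"
        USING_UGC_INDEX: Optional[bool] = Field(default=False)  # share one physical index across datasets

    print(LindormSettingsSketch().model_dump())
    # {'DEFAULT_INDEX_TYPE': 'hnsw', 'DEFAULT_DISTANCE_TYPE': 'l2', 'USING_UGC_INDEX': False}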


@@ -1,13 +1,10 @@
 import copy
 import json
 import logging
-from collections.abc import Iterable
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
-from tenacity import retry, stop_after_attempt, wait_fixed
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"
+
 
 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False
 
     @model_validator(mode="before")
     @classmethod
@@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
         return values
 
     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
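
For reference, with credentials set the simplified to_opensearch_params() yields a dict like the following (endpoint is a placeholder):

    hosts = "http://ld-example-proxy-search-pub.lindorm.aliyuncs.com:30070"  # placeholder endpoint
    username, password = "admin", "admin"

    params = {"hosts": hosts}
    if username and password:
        params["http_auth"] = (username, password)
    print(params)  # {'hosts': 'http://...', 'http_auth': ('admin', 'admin')}
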
@@ -51,9 +50,21 @@
 
 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
+        if config.using_ugc:
+            routing_value: str = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
         self.kwargs = kwargs
 
     def get_type(self) -> str:
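
In UGC mode the constructor now demands a routing_value keyword, lowercases it, and remembers the routing field name; without one it raises. A minimal sketch of that branch in isolation (function name and values are illustrative):

    ROUTING_FIELD = "routing_field"

    def resolve_index_identity(collection_name: str, using_ugc: bool, routing_value=None):
        # Mirrors __init__ above: UGC mode requires a routing value and lowercases both names.
        if using_ugc:
            if routing_value is None:
                raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
            return collection_name.lower(), routing_value.lower(), ROUTING_FIELD
        return collection_name.lower(), None, None

    print(resolve_index_identity("ugc_index_1536_hnsw_l2", True, "Vector_index_abc_Node"))
    # ('ugc_index_1536_hnsw_l2', 'vector_index_abc_node', 'routing_field')
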
@@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def __filter_existed_ids(
-        self,
-        texts: list[str],
-        metadatas: list[dict],
-        ids: list[str],
-        bulk_size: int = 1024,
-    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch {batch_ids}")
-                return set()
-
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(
-                    body={
-                        "docs": [
-                            {"_index": self._collection_name, "_id": id, "routing": routing}
-                            for id, routing in zip(batch_ids, route_ids)
-                        ]
-                    },
-                    _source=False,
-                )
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch ids: {batch_ids}")
-                return set()
-
-        if ids is None:
-            return texts, metadatas, ids
-
-        if len(texts) != len(ids):
-            raise RuntimeError(f"texts {len(texts)} != {ids}")
-
-        filtered_texts = []
-        filtered_metadatas = []
-        filtered_ids = []
-
-        def batch(iterable, n):
-            length = len(iterable)
-            for idx in range(0, length, n):
-                yield iterable[idx : min(idx + n, length)]
-
-        for ids_batch, texts_batch, metadatas_batch in zip(
-            batch(ids, bulk_size),
-            batch(texts, bulk_size),
-            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
-        ):
-            existing_ids_set = __fetch_existing_ids(ids_batch)
-            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
-                if doc_id not in existing_ids_set:
-                    filtered_texts.append(text)
-                    filtered_ids.append(doc_id)
-                    if metadatas is not None:
-                        filtered_metadatas.append(metadata)
-
-        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
-            }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
+            }
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            for item in response["items"]:
+                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()
 
     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
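
add_texts now assembles the OpenSearch bulk payload by hand: each document contributes an action-header line plus a source line, and in UGC mode the routing value rides in both places, on the header for shard routing and as a routing_field inside the document so term filters can scope reads. Roughly, one document contributes a pair like this (index name and values are placeholders; field names assume dify's usual Field enum values):

    doc_id = "doc-1"
    routing = "vector_index_abc_node"  # placeholder routing value

    action_header = {"index": {"_index": "ugc_index_1536_hnsw_l2", "_id": doc_id, "routing": routing}}
    action_values = {
        "page_content": "hello lindorm",  # Field.CONTENT_KEY.value
        "vector": [0.1, 0.2, 0.3],  # Field.VECTOR.value; real vectors match the index dimension
        "metadata": {"doc_id": doc_id},  # Field.METADATA_KEY.value
        "routing_field": routing,  # ROUTING_FIELD, present only when UGC is on
    }
    actions = [action_header, action_values]  # fed to self._client.bulk(actions)
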
@@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
         return None
 
     def delete_by_metadata_field(self, key: str, value: str):
-        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
-        results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        ids = self.get_ids_by_metadata_field(key, value)
         if ids:
             self.delete_by_ids(ids)
 
     def delete_by_ids(self, ids: list[str]) -> None:
+        params = {}
+        if self._using_ugc:
+            params["routing"] = self._routing
         for id in ids:
-            if self._client.exists(index=self._collection_name, id=id):
-                self._client.delete(index=self._collection_name, id=id)
+            if self._client.exists(index=self._collection_name, id=id, params=params):
+                params = {}
+                if self._using_ugc:
+                    params["routing"] = self._routing
+                self._client.delete(index=self._collection_name, id=id, params=params)
+                self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
     def delete(self) -> None:
-        try:
+        if self._using_ugc:
+            routing_filter_query = {
+                "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+            }
+            self._client.delete_by_query(self._collection_name, body=routing_filter_query)
+            self.refresh()
+        else:
             if self._client.indices.exists(index=self._collection_name):
                 self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                 logger.info("Delete index success")
             else:
                 logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
-        except Exception as e:
-            logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
-            raise e
 
     def text_exists(self, id: str) -> bool:
         try:
-            self._client.get(index=self._collection_name, id=id)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            self._client.get(index=self._collection_name, id=id, params=params)
             return True
         except:
             return False
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        # Make sure query_vector is a list
         if not isinstance(query_vector, list):
             raise ValueError("query_vector should be a list of floats")
 
-        # Check whether query_vector is a floating-point number list
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
         query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
-            response = self._client.search(index=self._collection_name, body=query)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception as e:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
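
Deletes, existence checks, and vector searches all repeat the same three lines to attach routing. A tiny helper expressing that pattern (hypothetical, not part of the commit):

    from typing import Optional

    def routing_params(using_ugc: bool, routing: Optional[str]) -> dict:
        # Per-request params: only UGC-mode requests carry a routing key.
        params: dict = {}
        if using_ugc and routing is not None:
            params["routing"] = routing
        return params

    assert routing_params(False, None) == {}
    assert routing_params(True, "tenant_a") == {"routing": "tenant_a"}
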
@@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
@@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
             logger.info(f"Collection {self._collection_name} already exists.")
             return
 
         if self._client.indices.exists(index=self._collection_name):
-            logger.info("{self._collection_name.lower()} already exists.")
+            logger.info(f"{self._collection_name.lower()} already exists.")
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
             return
 
         if len(self.kwargs) == 0 and len(kwargs) != 0:
             self.kwargs = copy.deepcopy(kwargs)
         vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-        shards = kwargs.pop("shards", 2)
+        shards = kwargs.pop("shards", 4)
         engine = kwargs.pop("engine", "lvector")
-        method_name = kwargs.pop("method_name", "hnsw")
+        method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+        space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
         data_type = kwargs.pop("data_type", "float")
-        space_type = kwargs.pop("space_type", "cosinesimil")
         hnsw_m = kwargs.pop("hnsw_m", 24)
         hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
         mapping = default_text_mapping(
             dimension,
             method_name,
+            space_type=space_type,
             shards=shards,
             engine=engine,
             data_type=data_type,
-            space_type=space_type,
             vector_field=vector_field,
             hnsw_m=hnsw_m,
             hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
             centroids_hnsw_m=centroids_hnsw_m,
             centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
             centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+            using_ugc=self._using_ugc,
             **kwargs,
         )
         self._client.indices.create(index=self._collection_name.lower(), body=mapping)
@@ -309,15 +283,20 @@
 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
-    routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosine"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)
 
     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}
 
-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
         mapping["settings"]["index"]["knn_routing"] = True
 
     return mapping
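
With UGC on, hnsw and flat indexes get knn_routing enabled, and ivfpq additionally switches on offline construction. The elif condition relies on Python precedence (and binds tighter than or), so it reads as "using_ugc and (hnsw or flat)". A sketch of the resulting settings fragment (the base mapping here is simplified; the real one also carries shards, engine, and the vector field mapping):

    mapping = {"settings": {"index": {"knn": True}}}  # simplified base mapping
    using_ugc, method_name = True, "hnsw"

    if using_ugc and method_name == "ivfpq":
        mapping["settings"]["index"]["knn_routing"] = True
        mapping["settings"]["index"]["knn.offline.construction"] = True
    elif using_ugc and method_name in ("hnsw", "flat"):  # equivalent to the chained and/or above
        mapping["settings"]["index"]["knn_routing"] = True

    print(mapping["settings"]["index"])  # {'knn': True, 'knn_routing': True}
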
@@ -386,14 +363,12 @@ def default_text_search_query(
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}
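
Routed full-text search now filters on the top-level routing field instead of metadata.routing_field. For illustration, the clause built above comes out as (values are placeholders):

    text_field, query_text = "page_content", "what is lindorm"
    routing_field, routing = "routing_field", "vector_index_abc_node"

    query_clause = {
        "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
    }
    print(query_clause)
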
@@ -483,16 +458,40 @@
 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            embedding_vector = embeddings.embed_query("hello word")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
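
Under UGC, every dataset whose embeddings share a dimension, index type, and distance type lands in one physical index, and the dataset's former per-collection name becomes its routing value. The naming scheme, sketched with an illustrative 1536-dimensional embedding:

    UGC_INDEX_PREFIX = "ugc_index"
    dimension, index_type, distance_type = 1536, "hnsw", "l2"  # illustrative values

    index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
    print(index_name)  # ugc_index_1536_hnsw_l2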


@@ -7,9 +7,10 @@ env = environs.Env()
 
 
 class Config:
-    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
+    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
     SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
-    SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
+    SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
+    USING_UGC = env.bool("USING_UGC", True)
 
 
 class TestLindormVectorStore(AbstractVectorTest):
@@ -31,5 +32,27 @@ class TestLindormVectorStore(AbstractVectorTest):
         assert ids[0] == self.example_doc_id
 
 
-def test_lindorm_vector(setup_mock_redis):
+class TestLindormVectorStoreUGC(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = LindormVectorStore(
+            collection_name="ugc_index_test",
+            config=LindormVectorStoreConfig(
+                hosts=Config.SEARCH_ENDPOINT,
+                username=Config.SEARCH_USERNAME,
+                password=Config.SEARCH_PWD,
+                using_ugc=Config.USING_UGC,
+            ),
+            routing_value=self.collection_name,
+        )
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
+        assert ids is not None
+        assert len(ids) == 1
+        assert ids[0] == self.example_doc_id
+
+
+def test_lindorm_vector_ugc(setup_mock_redis):
     TestLindormVectorStore().run_all_tests()
+    TestLindormVectorStoreUGC().run_all_tests()