mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 01:39:04 +08:00
Lindorm vdb (#11574)
Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
This commit is contained in:
parent
926f604f09
commit
0d04cdc323
@ -294,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
|
|||||||
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
||||||
LINDORM_USERNAME=admin
|
LINDORM_USERNAME=admin
|
||||||
LINDORM_PASSWORD=admin
|
LINDORM_PASSWORD=admin
|
||||||
|
USING_UGC_INDEX=False
|
||||||
|
|
||||||
# OceanBase Vector configuration
|
# OceanBase Vector configuration
|
||||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||||
|
@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
|
|||||||
description="Lindorm password",
|
description="Lindorm password",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
DEFAULT_INDEX_TYPE: Optional[str] = Field(
|
||||||
|
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
|
||||||
|
default="hnsw",
|
||||||
|
)
|
||||||
|
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
|
||||||
|
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
|
||||||
|
)
|
||||||
|
USING_UGC_INDEX: Optional[bool] = Field(
|
||||||
|
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Iterable
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from opensearchpy import OpenSearch
|
from opensearchpy import OpenSearch
|
||||||
from opensearchpy.helpers import bulk
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
|
|||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
logging.getLogger("lindorm").setLevel(logging.WARN)
|
logging.getLogger("lindorm").setLevel(logging.WARN)
|
||||||
|
|
||||||
|
ROUTING_FIELD = "routing_field"
|
||||||
|
UGC_INDEX_PREFIX = "ugc_index"
|
||||||
|
|
||||||
|
|
||||||
class LindormVectorStoreConfig(BaseModel):
|
class LindormVectorStoreConfig(BaseModel):
|
||||||
hosts: str
|
hosts: str
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
|
using_ugc: Optional[bool] = False
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def to_opensearch_params(self) -> dict[str, Any]:
|
def to_opensearch_params(self) -> dict[str, Any]:
|
||||||
params = {
|
params = {"hosts": self.hosts}
|
||||||
"hosts": self.hosts,
|
|
||||||
}
|
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
params["http_auth"] = (self.username, self.password)
|
params["http_auth"] = (self.username, self.password)
|
||||||
return params
|
return params
|
||||||
@ -51,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):
|
|||||||
|
|
||||||
class LindormVectorStore(BaseVector):
|
class LindormVectorStore(BaseVector):
|
||||||
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
|
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
|
||||||
super().__init__(collection_name.lower())
|
self._routing = None
|
||||||
|
self._routing_field = None
|
||||||
|
if config.using_ugc:
|
||||||
|
routing_value: str = kwargs.get("routing_value")
|
||||||
|
if routing_value is None:
|
||||||
|
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
|
||||||
|
self._routing = routing_value.lower()
|
||||||
|
self._routing_field = ROUTING_FIELD
|
||||||
|
ugc_index_name = collection_name
|
||||||
|
super().__init__(ugc_index_name.lower())
|
||||||
|
else:
|
||||||
|
super().__init__(collection_name.lower())
|
||||||
self._client_config = config
|
self._client_config = config
|
||||||
self._client = OpenSearch(**config.to_opensearch_params())
|
self._client = OpenSearch(**config.to_opensearch_params())
|
||||||
|
self._using_ugc = config.using_ugc
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
|
|||||||
def refresh(self):
|
def refresh(self):
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
self._client.indices.refresh(index=self._collection_name)
|
||||||
|
|
||||||
def __filter_existed_ids(
|
|
||||||
self,
|
|
||||||
texts: list[str],
|
|
||||||
metadatas: list[dict],
|
|
||||||
ids: list[str],
|
|
||||||
bulk_size: int = 1024,
|
|
||||||
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
|
||||||
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
|
|
||||||
try:
|
|
||||||
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
|
|
||||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Error fetching batch {batch_ids}")
|
|
||||||
return set()
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
|
||||||
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
|
|
||||||
try:
|
|
||||||
existing_docs = self._client.mget(
|
|
||||||
body={
|
|
||||||
"docs": [
|
|
||||||
{"_index": self._collection_name, "_id": id, "routing": routing}
|
|
||||||
for id, routing in zip(batch_ids, route_ids)
|
|
||||||
]
|
|
||||||
},
|
|
||||||
_source=False,
|
|
||||||
)
|
|
||||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Error fetching batch ids: {batch_ids}")
|
|
||||||
return set()
|
|
||||||
|
|
||||||
if ids is None:
|
|
||||||
return texts, metadatas, ids
|
|
||||||
|
|
||||||
if len(texts) != len(ids):
|
|
||||||
raise RuntimeError(f"texts {len(texts)} != {ids}")
|
|
||||||
|
|
||||||
filtered_texts = []
|
|
||||||
filtered_metadatas = []
|
|
||||||
filtered_ids = []
|
|
||||||
|
|
||||||
def batch(iterable, n):
|
|
||||||
length = len(iterable)
|
|
||||||
for idx in range(0, length, n):
|
|
||||||
yield iterable[idx : min(idx + n, length)]
|
|
||||||
|
|
||||||
for ids_batch, texts_batch, metadatas_batch in zip(
|
|
||||||
batch(ids, bulk_size),
|
|
||||||
batch(texts, bulk_size),
|
|
||||||
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
|
|
||||||
):
|
|
||||||
existing_ids_set = __fetch_existing_ids(ids_batch)
|
|
||||||
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
|
|
||||||
if doc_id not in existing_ids_set:
|
|
||||||
filtered_texts.append(text)
|
|
||||||
filtered_ids.append(doc_id)
|
|
||||||
if metadatas is not None:
|
|
||||||
filtered_metadatas.append(metadata)
|
|
||||||
|
|
||||||
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
|
|
||||||
|
|
||||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||||
actions = []
|
actions = []
|
||||||
uuids = self._get_uuids(documents)
|
uuids = self._get_uuids(documents)
|
||||||
for i in range(len(documents)):
|
for i in range(len(documents)):
|
||||||
action = {
|
action_header = {
|
||||||
"_op_type": "index",
|
"index": {
|
||||||
"_index": self._collection_name.lower(),
|
"_index": self.collection_name.lower(),
|
||||||
"_id": uuids[i],
|
"_id": uuids[i],
|
||||||
"_source": {
|
}
|
||||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
|
||||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
|
||||||
Field.METADATA_KEY.value: documents[i].metadata,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
actions.append(action)
|
action_values = {
|
||||||
bulk(self._client, actions)
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||||
self.refresh()
|
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||||
|
Field.METADATA_KEY.value: documents[i].metadata,
|
||||||
|
}
|
||||||
|
if self._using_ugc:
|
||||||
|
action_header["index"]["routing"] = self._routing
|
||||||
|
action_values[self._routing_field] = self._routing
|
||||||
|
actions.append(action_header)
|
||||||
|
actions.append(action_values)
|
||||||
|
response = self._client.bulk(actions)
|
||||||
|
if response["errors"]:
|
||||||
|
for item in response["items"]:
|
||||||
|
print(f"{item['index']['status']}: {item['index']['error']['type']}")
|
||||||
|
else:
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
|
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
|
||||||
|
if self._using_ugc:
|
||||||
|
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
|
||||||
response = self._client.search(index=self._collection_name, body=query)
|
response = self._client.search(index=self._collection_name, body=query)
|
||||||
if response["hits"]["hits"]:
|
if response["hits"]["hits"]:
|
||||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||||
@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
ids = self.get_ids_by_metadata_field(key, value)
|
||||||
results = self._client.search(index=self._collection_name, body=query_str)
|
|
||||||
ids = [hit["_id"] for hit in results["hits"]["hits"]]
|
|
||||||
if ids:
|
if ids:
|
||||||
self.delete_by_ids(ids)
|
self.delete_by_ids(ids)
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str]) -> None:
|
def delete_by_ids(self, ids: list[str]) -> None:
|
||||||
|
params = {}
|
||||||
|
if self._using_ugc:
|
||||||
|
params["routing"] = self._routing
|
||||||
for id in ids:
|
for id in ids:
|
||||||
if self._client.exists(index=self._collection_name, id=id):
|
if self._client.exists(index=self._collection_name, id=id, params=params):
|
||||||
self._client.delete(index=self._collection_name, id=id)
|
params = {}
|
||||||
|
if self._using_ugc:
|
||||||
|
params["routing"] = self._routing
|
||||||
|
self._client.delete(index=self._collection_name, id=id, params=params)
|
||||||
|
self.refresh()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
||||||
|
|
||||||
def delete(self) -> None:
|
def delete(self) -> None:
|
||||||
try:
|
if self._using_ugc:
|
||||||
|
routing_filter_query = {
|
||||||
|
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
|
||||||
|
}
|
||||||
|
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
|
||||||
|
self.refresh()
|
||||||
|
else:
|
||||||
if self._client.indices.exists(index=self._collection_name):
|
if self._client.indices.exists(index=self._collection_name):
|
||||||
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
|
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
|
||||||
logger.info("Delete index success")
|
logger.info("Delete index success")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
|
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def text_exists(self, id: str) -> bool:
|
def text_exists(self, id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
self._client.get(index=self._collection_name, id=id)
|
params = {}
|
||||||
|
if self._using_ugc:
|
||||||
|
params["routing"] = self._routing
|
||||||
|
self._client.get(index=self._collection_name, id=id, params=params)
|
||||||
return True
|
return True
|
||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||||
# Make sure query_vector is a list
|
|
||||||
if not isinstance(query_vector, list):
|
if not isinstance(query_vector, list):
|
||||||
raise ValueError("query_vector should be a list of floats")
|
raise ValueError("query_vector should be a list of floats")
|
||||||
|
|
||||||
# Check whether query_vector is a floating-point number list
|
|
||||||
if not all(isinstance(x, float) for x in query_vector):
|
if not all(isinstance(x, float) for x in query_vector):
|
||||||
raise ValueError("All elements in query_vector should be floats")
|
raise ValueError("All elements in query_vector should be floats")
|
||||||
|
|
||||||
top_k = kwargs.get("top_k", 10)
|
top_k = kwargs.get("top_k", 10)
|
||||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||||
try:
|
try:
|
||||||
response = self._client.search(index=self._collection_name, body=query)
|
params = {}
|
||||||
|
if self._using_ugc:
|
||||||
|
params["routing"] = self._routing
|
||||||
|
response = self._client.search(index=self._collection_name, body=query, params=params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error executing vector search, query: {query}")
|
logger.exception(f"Error executing vector search, query: {query}")
|
||||||
raise
|
raise
|
||||||
@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
|
|||||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||||
top_k = kwargs.get("top_k", 10)
|
top_k = kwargs.get("top_k", 10)
|
||||||
filters = kwargs.get("filter")
|
filters = kwargs.get("filter")
|
||||||
routing = kwargs.get("routing")
|
routing = self._routing
|
||||||
full_text_query = default_text_search_query(
|
full_text_query = default_text_search_query(
|
||||||
query_text=query,
|
query_text=query,
|
||||||
k=top_k,
|
k=top_k,
|
||||||
@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
|
|||||||
minimum_should_match=minimum_should_match,
|
minimum_should_match=minimum_should_match,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
routing=routing,
|
routing=routing,
|
||||||
|
routing_field=self._routing_field,
|
||||||
)
|
)
|
||||||
response = self._client.search(index=self._collection_name, body=full_text_query)
|
response = self._client.search(index=self._collection_name, body=full_text_query)
|
||||||
docs = []
|
docs = []
|
||||||
@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
|
|||||||
logger.info(f"Collection {self._collection_name} already exists.")
|
logger.info(f"Collection {self._collection_name} already exists.")
|
||||||
return
|
return
|
||||||
if self._client.indices.exists(index=self._collection_name):
|
if self._client.indices.exists(index=self._collection_name):
|
||||||
logger.info("{self._collection_name.lower()} already exists.")
|
logger.info(f"{self._collection_name.lower()} already exists.")
|
||||||
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
return
|
return
|
||||||
if len(self.kwargs) == 0 and len(kwargs) != 0:
|
if len(self.kwargs) == 0 and len(kwargs) != 0:
|
||||||
self.kwargs = copy.deepcopy(kwargs)
|
self.kwargs = copy.deepcopy(kwargs)
|
||||||
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
|
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
|
||||||
shards = kwargs.pop("shards", 2)
|
shards = kwargs.pop("shards", 4)
|
||||||
|
|
||||||
engine = kwargs.pop("engine", "lvector")
|
engine = kwargs.pop("engine", "lvector")
|
||||||
method_name = kwargs.pop("method_name", "hnsw")
|
method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
|
||||||
|
space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
|
||||||
data_type = kwargs.pop("data_type", "float")
|
data_type = kwargs.pop("data_type", "float")
|
||||||
space_type = kwargs.pop("space_type", "cosinesimil")
|
|
||||||
|
|
||||||
hnsw_m = kwargs.pop("hnsw_m", 24)
|
hnsw_m = kwargs.pop("hnsw_m", 24)
|
||||||
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
||||||
@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
|
|||||||
mapping = default_text_mapping(
|
mapping = default_text_mapping(
|
||||||
dimension,
|
dimension,
|
||||||
method_name,
|
method_name,
|
||||||
|
space_type=space_type,
|
||||||
shards=shards,
|
shards=shards,
|
||||||
engine=engine,
|
engine=engine,
|
||||||
data_type=data_type,
|
data_type=data_type,
|
||||||
space_type=space_type,
|
|
||||||
vector_field=vector_field,
|
vector_field=vector_field,
|
||||||
hnsw_m=hnsw_m,
|
hnsw_m=hnsw_m,
|
||||||
hnsw_ef_construction=hnsw_ef_construction,
|
hnsw_ef_construction=hnsw_ef_construction,
|
||||||
@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
|
|||||||
centroids_hnsw_m=centroids_hnsw_m,
|
centroids_hnsw_m=centroids_hnsw_m,
|
||||||
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
|
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
|
||||||
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
|
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
|
||||||
|
using_ugc=self._using_ugc,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
|
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
|
||||||
@ -309,15 +283,20 @@ class LindormVectorStore(BaseVector):
|
|||||||
|
|
||||||
|
|
||||||
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
|
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
|
||||||
routing_field = kwargs.get("routing_field")
|
|
||||||
excludes_from_source = kwargs.get("excludes_from_source")
|
excludes_from_source = kwargs.get("excludes_from_source")
|
||||||
analyzer = kwargs.get("analyzer", "ik_max_word")
|
analyzer = kwargs.get("analyzer", "ik_max_word")
|
||||||
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
|
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
|
||||||
engine = kwargs["engine"]
|
engine = kwargs["engine"]
|
||||||
shard = kwargs["shards"]
|
shard = kwargs["shards"]
|
||||||
space_type = kwargs["space_type"]
|
space_type = kwargs.get("space_type")
|
||||||
|
if space_type is None:
|
||||||
|
if method_name == "hnsw":
|
||||||
|
space_type = "l2"
|
||||||
|
else:
|
||||||
|
space_type = "cosine"
|
||||||
data_type = kwargs["data_type"]
|
data_type = kwargs["data_type"]
|
||||||
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
|
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
|
||||||
|
using_ugc = kwargs.get("using_ugc", False)
|
||||||
|
|
||||||
if method_name == "ivfpq":
|
if method_name == "ivfpq":
|
||||||
ivfpq_m = kwargs["ivfpq_m"]
|
ivfpq_m = kwargs["ivfpq_m"]
|
||||||
@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
|
|||||||
if excludes_from_source:
|
if excludes_from_source:
|
||||||
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
|
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
|
||||||
|
|
||||||
if method_name == "ivfpq" and routing_field is not None:
|
if using_ugc and method_name == "ivfpq":
|
||||||
mapping["settings"]["index"]["knn_routing"] = True
|
mapping["settings"]["index"]["knn_routing"] = True
|
||||||
mapping["settings"]["index"]["knn.offline.construction"] = True
|
mapping["settings"]["index"]["knn.offline.construction"] = True
|
||||||
|
elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
|
||||||
if method_name == "flat" and routing_field is not None:
|
|
||||||
mapping["settings"]["index"]["knn_routing"] = True
|
mapping["settings"]["index"]["knn_routing"] = True
|
||||||
|
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
@ -386,14 +363,12 @@ def default_text_search_query(
|
|||||||
minimum_should_match: int = 0,
|
minimum_should_match: int = 0,
|
||||||
filters: Optional[list[dict]] = None,
|
filters: Optional[list[dict]] = None,
|
||||||
routing: Optional[str] = None,
|
routing: Optional[str] = None,
|
||||||
|
routing_field: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if routing is not None:
|
if routing is not None:
|
||||||
routing_field = kwargs.get("routing_field", "routing_field")
|
|
||||||
query_clause = {
|
query_clause = {
|
||||||
"bool": {
|
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
|
||||||
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
query_clause = {"match": {text_field: query_text}}
|
query_clause = {"match": {text_field: query_text}}
|
||||||
@ -483,16 +458,40 @@ def default_vector_search_query(
|
|||||||
|
|
||||||
class LindormVectorStoreFactory(AbstractVectorFactory):
|
class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
||||||
if dataset.index_struct_dict:
|
|
||||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
|
||||||
collection_name = class_prefix
|
|
||||||
else:
|
|
||||||
dataset_id = dataset.id
|
|
||||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
|
||||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
|
|
||||||
lindorm_config = LindormVectorStoreConfig(
|
lindorm_config = LindormVectorStoreConfig(
|
||||||
hosts=dify_config.LINDORM_URL,
|
hosts=dify_config.LINDORM_URL,
|
||||||
username=dify_config.LINDORM_USERNAME,
|
username=dify_config.LINDORM_USERNAME,
|
||||||
password=dify_config.LINDORM_PASSWORD,
|
password=dify_config.LINDORM_PASSWORD,
|
||||||
|
using_ugc=dify_config.USING_UGC_INDEX,
|
||||||
)
|
)
|
||||||
return LindormVectorStore(collection_name, lindorm_config)
|
using_ugc = dify_config.USING_UGC_INDEX
|
||||||
|
routing_value = None
|
||||||
|
if dataset.index_struct:
|
||||||
|
if using_ugc:
|
||||||
|
dimension = dataset.index_struct_dict["dimension"]
|
||||||
|
index_type = dataset.index_struct_dict["index_type"]
|
||||||
|
distance_type = dataset.index_struct_dict["distance_type"]
|
||||||
|
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
|
||||||
|
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
else:
|
||||||
|
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||||
|
else:
|
||||||
|
embedding_vector = embeddings.embed_query("hello word")
|
||||||
|
dimension = len(embedding_vector)
|
||||||
|
index_type = dify_config.DEFAULT_INDEX_TYPE
|
||||||
|
distance_type = dify_config.DEFAULT_DISTANCE_TYPE
|
||||||
|
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
|
||||||
|
index_struct_dict = {
|
||||||
|
"type": VectorType.LINDORM,
|
||||||
|
"vector_store": {"class_prefix": class_prefix},
|
||||||
|
"index_type": index_type,
|
||||||
|
"dimension": dimension,
|
||||||
|
"distance_type": distance_type,
|
||||||
|
}
|
||||||
|
dataset.index_struct = json.dumps(index_struct_dict)
|
||||||
|
if using_ugc:
|
||||||
|
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
|
||||||
|
routing_value = class_prefix
|
||||||
|
else:
|
||||||
|
index_name = class_prefix
|
||||||
|
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
|
||||||
|
@ -7,9 +7,10 @@ env = environs.Env()
|
|||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
|
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
|
||||||
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
|
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
|
||||||
SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
|
SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
|
||||||
|
USING_UGC = env.bool("USING_UGC", True)
|
||||||
|
|
||||||
|
|
||||||
class TestLindormVectorStore(AbstractVectorTest):
|
class TestLindormVectorStore(AbstractVectorTest):
|
||||||
@ -31,5 +32,27 @@ class TestLindormVectorStore(AbstractVectorTest):
|
|||||||
assert ids[0] == self.example_doc_id
|
assert ids[0] == self.example_doc_id
|
||||||
|
|
||||||
|
|
||||||
def test_lindorm_vector(setup_mock_redis):
|
class TestLindormVectorStoreUGC(AbstractVectorTest):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.vector = LindormVectorStore(
|
||||||
|
collection_name="ugc_index_test",
|
||||||
|
config=LindormVectorStoreConfig(
|
||||||
|
hosts=Config.SEARCH_ENDPOINT,
|
||||||
|
username=Config.SEARCH_USERNAME,
|
||||||
|
password=Config.SEARCH_PWD,
|
||||||
|
using_ugc=Config.USING_UGC,
|
||||||
|
),
|
||||||
|
routing_value=self.collection_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ids_by_metadata_field(self):
|
||||||
|
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||||
|
assert ids is not None
|
||||||
|
assert len(ids) == 1
|
||||||
|
assert ids[0] == self.example_doc_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_lindorm_vector_ugc(setup_mock_redis):
|
||||||
TestLindormVectorStore().run_all_tests()
|
TestLindormVectorStore().run_all_tests()
|
||||||
|
TestLindormVectorStoreUGC().run_all_tests()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user