feat: rewrite Elasticsearch index and search code to support vector and full-text search (#7641)
Co-authored-by: haokai <haokai@shuwen.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: wellCh4n <wellCh4n@foxmail.com>
This commit is contained in:
parent e7afee1176
commit 122ce41020
@@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
+from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -200,5 +201,6 @@ class MiddlewareConfig(
     TencentVectorDBConfig,
     TiDBVectorConfig,
     WeaviateConfig,
+    ElasticsearchConfig,
 ):
     pass
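MiddlewareConfig aggregates all middleware settings classes via multiple inheritance, so registering ElasticsearchConfig as one more base class is all that is needed to expose the new ELASTICSEARCH_* fields on the shared settings object. A minimal sketch of that pattern, assuming pydantic-settings is installed (the stand-in mixin below is illustrative, not the project's real class):

import os

from pydantic import Field
from pydantic_settings import BaseSettings


class SomeOtherConfig(BaseSettings):  # hypothetical stand-in for the existing mixins
    SOME_OTHER_HOST: str = Field(default="127.0.0.1")


class ElasticsearchConfig(BaseSettings):
    ELASTICSEARCH_HOST: str = Field(default="127.0.0.1")


class MiddlewareConfig(
    SomeOtherConfig,
    ElasticsearchConfig,
):
    pass


config = MiddlewareConfig()  # every inherited field can be overridden via an environment variable
print(config.ELASTICSEARCH_HOST)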
api/configs/middleware/vdb/elasticsearch_config.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class ElasticsearchConfig(BaseSettings):
+    """
+    Elasticsearch configs
+    """
+
+    ELASTICSEARCH_HOST: Optional[str] = Field(
+        description="Elasticsearch host",
+        default="127.0.0.1",
+    )
+
+    ELASTICSEARCH_PORT: PositiveInt = Field(
+        description="Elasticsearch port",
+        default=9200,
+    )
+
+    ELASTICSEARCH_USERNAME: Optional[str] = Field(
+        description="Elasticsearch username",
+        default="elastic",
+    )
+
+    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
+        description="Elasticsearch password",
+        default="elastic",
+    )
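Each field of the new settings class is read from an environment variable of the same name, with pydantic coercing the raw string to the declared type. A quick standalone sketch of how that behaves (the override value is hypothetical):

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class ElasticsearchConfig(BaseSettings):
    ELASTICSEARCH_PORT: PositiveInt = Field(default=9200)


os.environ["ELASTICSEARCH_PORT"] = "9201"  # hypothetical override
assert ElasticsearchConfig().ELASTICSEARCH_PORT == 9201  # string coerced to PositiveInt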
@@ -1,5 +1,7 @@
 import json
-from typing import Any
+import logging
+from typing import Any, Optional
+from urllib.parse import urlparse
 
 import requests
 from elasticsearch import Elasticsearch
@@ -7,16 +9,20 @@ from flask import current_app
 from pydantic import BaseModel, model_validator
 
 from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
+logger = logging.getLogger(__name__)
+
 
 class ElasticSearchConfig(BaseModel):
     host: str
-    port: str
+    port: int
     username: str
     password: str
 
@@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
     def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
         super().__init__(index_name.lower())
         self._client = self._init_client(config)
+        self._version = self._get_version()
+        self._check_version()
         self._attributes = attributes
 
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
+            parsed_url = urlparse(config.host)
+            if parsed_url.scheme in ['http', 'https']:
+                hosts = f'{config.host}:{config.port}'
+            else:
+                hosts = f'http://{config.host}:{config.port}'
             client = Elasticsearch(
-                hosts=f'{config.host}:{config.port}',
+                hosts=hosts,
                 basic_auth=(config.username, config.password),
                 request_timeout=100000,
                 retry_on_timeout=True,
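With this change, config.host may be either a bare hostname or a full URL; urlparse only reports a scheme when one is present, so bare hosts are normalized with an http:// prefix before the port is appended. Both accepted forms, illustrated standalone:

from urllib.parse import urlparse

for host in ('127.0.0.1', 'https://es.example.com'):
    scheme = urlparse(host).scheme
    hosts = f'{host}:9200' if scheme in ['http', 'https'] else f'http://{host}:9200'
    print(hosts)
# http://127.0.0.1:9200
# https://es.example.com:9200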
@@ -53,42 +66,27 @@ class ElasticSearchVector(BaseVector):
 
         return client
 
+    def _get_version(self) -> str:
+        info = self._client.info()
+        return info['version']['number']
+
+    def _check_version(self):
+        if self._version < '8.0.0':
+            raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
+
     def get_type(self) -> str:
         return 'elasticsearch'
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
-        texts = [d.page_content for d in documents]
-        metadatas = [d.metadata for d in documents]
-
-        if not self._client.indices.exists(index=self._collection_name):
-            dim = len(embeddings[0])
-            mapping = {
-                "properties": {
-                    "text": {
-                        "type": "text"
-                    },
-                    "vector": {
-                        "type": "dense_vector",
-                        "index": True,
-                        "dims": dim,
-                        "similarity": "l2_norm"
-                    },
-                }
-            }
-            self._client.indices.create(index=self._collection_name, mappings=mapping)
-
-        added_ids = []
-        for i, text in enumerate(texts):
+        for i in range(len(documents)):
             self._client.index(index=self._collection_name,
                                id=uuids[i],
                                document={
-                                   "text": text,
-                                   "vector": embeddings[i] if embeddings[i] else None,
-                                   "metadata": metadatas[i] if metadatas[i] else {},
+                                   Field.CONTENT_KEY.value: documents[i].page_content,
+                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
                                })
-            added_ids.append(uuids[i])
-
 
         self._client.indices.refresh(index=self._collection_name)
         return uuids
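The new _check_version gate compares version strings with <, which is lexicographic: it correctly rejects 7.x releases ('7...' sorts below '8.0.0'), though a hypothetical '10.x' would also sort below '8.0.0' and be rejected. A sketch of a numeric comparison using the third-party packaging library, if stricter ordering were wanted (an alternative, not what this commit does):

from packaging.version import Version  # third-party: packaging

def is_supported(version: str) -> bool:
    # Numeric comparison avoids the lexicographic pitfall of '10.0.0' < '8.0.0'.
    return Version(version) >= Version('8.0.0')

assert is_supported('8.11.3')
assert not is_supported('7.17.9')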
@@ -116,28 +114,21 @@ class ElasticSearchVector(BaseVector):
         self._client.indices.delete(index=self._collection_name)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        query_str = {
-            "query": {
-                "script_score": {
-                    "query": {
-                        "match_all": {}
-                    },
-                    "script": {
-                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
-                        "params": {
-                            "query_vector": query_vector
-                        }
-                    }
-                }
-            }
-        }
+        top_k = kwargs.get("top_k", 10)
+        knn = {
+            "field": Field.VECTOR.value,
+            "query_vector": query_vector,
+            "k": top_k
+        }
 
-        results = self._client.search(index=self._collection_name, body=query_str)
+        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
 
         docs_and_scores = []
         for hit in results['hits']['hits']:
             docs_and_scores.append(
-                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
+                          vector=hit['_source'][Field.VECTOR.value],
+                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
 
         docs = []
         for doc, score in docs_and_scores:
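search_by_vector now uses the top-level knn option of the Elasticsearch 8.x search API, which runs approximate nearest-neighbour retrieval against the dense_vector field instead of brute-force script scoring over every document. The client call is roughly equivalent to this raw request (values illustrative):

import requests

# Roughly the request body the client call above produces:
body = {
    "knn": {
        "field": "vector",           # Field.VECTOR.value
        "query_vector": [0.1, 0.2],  # the embedded query
        "k": 10,                     # top_k
    },
    "size": 10,
}
# requests.post('http://127.0.0.1:9200/<index>/_search', json=body,
#               auth=('elastic', 'elastic'))  # hedged: same search over plain REST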
@@ -146,25 +137,61 @@ class ElasticSearchVector(BaseVector):
             doc.metadata['score'] = score
             docs.append(doc)
 
-        # Sort the documents by score in descending order
-        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
-
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {
             "match": {
-                "text": query
+                Field.CONTENT_KEY.value: query
             }
         }
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
         for hit in results['hits']['hits']:
-            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+            docs.append(Document(
+                page_content=hit['_source'][Field.CONTENT_KEY.value],
+                vector=hit['_source'][Field.VECTOR.value],
+                metadata=hit['_source'][Field.METADATA_KEY.value],
+            ))
 
         return docs
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        return self.add_texts(texts, embeddings, **kwargs)
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(
+        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name):
+                dim = len(embeddings[0])
+                mappings = {
+                    "properties": {
+                        Field.CONTENT_KEY.value: {"type": "text"},
+                        Field.VECTOR.value: {  # Make sure the dimension is correct here
+                            "type": "dense_vector",
+                            "dims": dim,
+                            "similarity": "cosine"
+                        },
+                        Field.METADATA_KEY.value: {
+                            "type": "object",
+                            "properties": {
+                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                            }
+                        }
+                    }
+                }
+                self._client.indices.create(index=self._collection_name, mappings=mappings)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
 
 class ElasticSearchVectorFactory(AbstractVectorFactory):
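create_collection serializes concurrent index creation behind a short-lived Redis lock and then caches the "already created" fact for an hour, so parallel indexing workers do not race on indices.create. The pattern, condensed into a standalone sketch (function name and arguments are illustrative):

def ensure_index_once(redis_client, es_client, index_name: str, mappings: dict):
    # Condensed sketch of the lock-and-cache pattern used by create_collection.
    with redis_client.lock(f'vector_indexing_lock_{index_name}', timeout=20):
        cache_key = f'vector_indexing_{index_name}'
        if redis_client.get(cache_key):
            return  # another worker already created the index
        if not es_client.indices.exists(index=index_name):
            es_client.indices.create(index=index_name, mappings=mappings)
        redis_client.set(cache_key, 1, ex=3600)  # remember success for an hour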