feat: rewrite Elasticsearch index and search code to achieve Elasticsearch vector and full-text search (#7641)

Co-authored-by: haokai <haokai@shuwen.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: wellCh4n <wellCh4n@foxmail.com>
This commit is contained in:
Kenn 2024-08-27 11:43:44 +08:00 committed by GitHub
parent e7afee1176
commit 122ce41020
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 52 deletions

View File

@ -13,6 +13,7 @@ from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@ -200,5 +201,6 @@ class MiddlewareConfig(
TencentVectorDBConfig, TencentVectorDBConfig,
TiDBVectorConfig, TiDBVectorConfig,
WeaviateConfig, WeaviateConfig,
ElasticsearchConfig,
): ):
pass pass

View File

@ -0,0 +1,30 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
    """
    Connection settings for Elasticsearch: host, port and basic-auth credentials.
    """

    # Host used to reach Elasticsearch.
    ELASTICSEARCH_HOST: Optional[str] = Field(
        default="127.0.0.1",
        description="Elasticsearch host",
    )

    # Port used to reach Elasticsearch; must be a positive integer.
    ELASTICSEARCH_PORT: PositiveInt = Field(
        default=9200,
        description="Elasticsearch port",
    )

    # User name for authenticating against Elasticsearch.
    ELASTICSEARCH_USERNAME: Optional[str] = Field(
        default="elastic",
        description="Elasticsearch username",
    )

    # Password paired with ELASTICSEARCH_USERNAME.
    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
        default="elastic",
        description="Elasticsearch password",
    )

View File

@ -1,5 +1,7 @@
import json import json
from typing import Any import logging
from typing import Any, Optional
from urllib.parse import urlparse
import requests import requests
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@ -7,16 +9,20 @@ from flask import current_app
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
logger = logging.getLogger(__name__)
class ElasticSearchConfig(BaseModel): class ElasticSearchConfig(BaseModel):
host: str host: str
port: str port: int
username: str username: str
password: str password: str
@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
super().__init__(index_name.lower()) super().__init__(index_name.lower())
self._client = self._init_client(config) self._client = self._init_client(config)
self._version = self._get_version()
self._check_version()
self._attributes = attributes self._attributes = attributes
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try: try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
else:
hosts = f'http://{config.host}:{config.port}'
client = Elasticsearch( client = Elasticsearch(
hosts=f'{config.host}:{config.port}', hosts=hosts,
basic_auth=(config.username, config.password), basic_auth=(config.username, config.password),
request_timeout=100000, request_timeout=100000,
retry_on_timeout=True, retry_on_timeout=True,
@ -53,42 +66,27 @@ class ElasticSearchVector(BaseVector):
return client return client
def _get_version(self) -> str:
    """Return the ``version.number`` string reported by the client's ``info()`` call."""
    return self._client.info()['version']['number']
def _check_version(self):
if self._version < '8.0.0':
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str: def get_type(self) -> str:
return 'elasticsearch' return 'elasticsearch'
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents) uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents] for i in range(len(documents)):
metadatas = [d.metadata for d in documents]
if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mapping = {
"properties": {
"text": {
"type": "text"
},
"vector": {
"type": "dense_vector",
"index": True,
"dims": dim,
"similarity": "l2_norm"
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mapping)
added_ids = []
for i, text in enumerate(texts):
self._client.index(index=self._collection_name, self._client.index(index=self._collection_name,
id=uuids[i], id=uuids[i],
document={ document={
"text": text, Field.CONTENT_KEY.value: documents[i].page_content,
"vector": embeddings[i] if embeddings[i] else None, Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
"metadata": metadatas[i] if metadatas[i] else {}, Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
}) })
added_ids.append(uuids[i])
self._client.indices.refresh(index=self._collection_name) self._client.indices.refresh(index=self._collection_name)
return uuids return uuids
@ -116,28 +114,21 @@ class ElasticSearchVector(BaseVector):
self._client.indices.delete(index=self._collection_name) self._client.indices.delete(index=self._collection_name)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_str = { top_k = kwargs.get("top_k", 10)
"query": { knn = {
"script_score": { "field": Field.VECTOR.value,
"query": { "query_vector": query_vector,
"match_all": {} "k": top_k
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {
"query_vector": query_vector
}
}
}
}
} }
results = self._client.search(index=self._collection_name, body=query_str) results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
docs_and_scores = [] docs_and_scores = []
for hit in results['hits']['hits']: for hit in results['hits']['hits']:
docs_and_scores.append( docs_and_scores.append(
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score'])) (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
docs = [] docs = []
for doc, score in docs_and_scores: for doc, score in docs_and_scores:
@ -146,25 +137,61 @@ class ElasticSearchVector(BaseVector):
doc.metadata['score'] = score doc.metadata['score'] = score
docs.append(doc) docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = { query_str = {
"match": { "match": {
"text": query Field.CONTENT_KEY.value: query
} }
} }
results = self._client.search(index=self._collection_name, query=query_str) results = self._client.search(index=self._collection_name, query=query_str)
docs = [] docs = []
for hit in results['hits']['hits']: for hit in results['hits']['hits']:
docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata'])) docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))
return docs return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
return self.add_texts(texts, embeddings, **kwargs) metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
# Create the Elasticsearch index for this collection if it is missing,
# guarded by a Redis lock and a Redis "already handled" cache key.
#
# embeddings: only len(embeddings[0]) is read, to set the dense_vector "dims".
# metadatas, index_params: accepted but not referenced in the visible body.
#
# NOTE(review): indentation is lost in this view, so the exact nesting (what
# runs inside the lock, and whether the final redis_client.set sits inside the
# `if not ... exists` branch) cannot be confirmed here - verify against the file.
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
# Per-collection lock name; the lock is requested with timeout=20.
lock_name = f'vector_indexing_lock_{self._collection_name}'
with redis_client.lock(lock_name, timeout=20):
# A truthy cache entry short-circuits: log and return without touching Elasticsearch.
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
# Only build and send the mapping when the index does not exist yet.
if not self._client.indices.exists(index=self._collection_name):
# Vector dimension is taken from the first embedding.
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)
# Record the cache key with ex=3600 (presumably an expiry in seconds - confirm against the Redis client docs).
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class ElasticSearchVectorFactory(AbstractVectorFactory): class ElasticSearchVectorFactory(AbstractVectorFactory):