fix: using api can not execute relyt vector database (#3766)

Co-authored-by: jingsi <jingsi@leadincloud.com>
Jingpan Xiong authored on 2024-04-25 19:46:20 +08:00; committed by GitHub
parent bf9fc8fef4
commit 1be222af2e
4 changed files with 186 additions and 50 deletions

Changed file 1 of 4 (dataset retrieval settings API)

@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
@@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
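
Both hunks apply the same gating: Relyt, like Milvus, has no BM25 full-text search (see the search_by_full_text stub in the vector-store file below), so the retrieval-settings endpoints advertise only semantic search for it. A minimal standalone sketch of the pattern; the helper set and the method names in the non-semantic branch are illustrative, not from the commit:

SEMANTIC_ONLY_STORES = {'milvus', 'relyt'}  # stores without BM25/full-text support

def retrieval_setting(vector_type: str) -> dict:
    # Semantic-only stores expose a single retrieval method ...
    if vector_type in SEMANTIC_ONLY_STORES:
        return {'retrieval_method': ['semantic_search']}
    # ... while other stores may additionally offer full-text and hybrid retrieval.
    return {'retrieval_method': ['semantic_search', 'full_text_search', 'hybrid_search']}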

Changed file 2 of 4 (Relyt vector store client)

@@ -1,16 +1,23 @@
-import logging
-from typing import Any
+import uuid
+from typing import Any, Optional
 
-from pgvecto_rs.sdk import PGVectoRs, Record
 from pydantic import BaseModel, root_validator
+from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
+from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
+try:
+    from sqlalchemy.orm import declarative_base
+except ImportError:
+    from sqlalchemy.ext.declarative import declarative_base
+
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 
-logger = logging.getLogger(__name__)
+Base = declarative_base()  # type: Any
 
 
 class RelytConfig(BaseModel):
     host: str
@@ -36,16 +43,14 @@ class RelytConfig(BaseModel):
 
 class RelytVector(BaseVector):
 
-    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
+    def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
         super().__init__(collection_name)
+        self.embedding_dimension = 1536
         self._client_config = config
         self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
-        self._client = PGVectoRs(
-            db_url=self._url,
-            collection_name=self._collection_name,
-            dimension=dim
-        )
+        self.client = create_engine(self._url)
         self._fields = []
+        self._group_id = group_id
 
     def get_type(self) -> str:
         return 'relyt'
@@ -54,6 +59,7 @@ class RelytVector(BaseVector):
         index_params = {}
         metadatas = [d.metadata for d in texts]
         self.create_collection(len(embeddings[0]))
+        self.embedding_dimension = len(embeddings[0])
         self.add_texts(texts, embeddings)
 
     def create_collection(self, dimension: int):
@@ -63,21 +69,21 @@ class RelytVector(BaseVector):
         if redis_client.get(collection_exist_cache_key):
             return
         index_name = f"{self._collection_name}_embedding_index"
-        with Session(self._client._engine) as session:
-            drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
+        with Session(self.client) as session:
+            drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
             session.execute(drop_statement)
             create_statement = sql_text(f"""
-                CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
-                    id UUID PRIMARY KEY,
-                    text TEXT NOT NULL,
-                    meta JSONB NOT NULL,
+                CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
+                    id TEXT PRIMARY KEY,
+                    document TEXT NOT NULL,
+                    metadata JSON NOT NULL,
                     embedding vector({dimension}) NOT NULL
                 ) using heap;
             """)
             session.execute(create_statement)
             index_statement = sql_text(f"""
                 CREATE INDEX {index_name}
-                ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
+                ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
                 WITH (options = $$
                 optimizing.optimizing_threads = 30
                 segment.max_growing_segment_size = 2000
@@ -92,21 +98,62 @@ class RelytVector(BaseVector):
         redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
-        pks = [str(r.id) for r in records]
-        self._client.insert(records)
-        return pks
+        from pgvecto_rs.sqlalchemy import Vector
+
+        ids = [str(uuid.uuid1()) for _ in documents]
+        metadatas = [d.metadata for d in documents]
+        for metadata in metadatas:
+            metadata['group_id'] = self._group_id
+        texts = [d.page_content for d in documents]
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(len(embeddings[0]))),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        chunks_table_data = []
+        with self.client.connect() as conn:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                    texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == 500:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+        return ids
 
     def delete_by_document_id(self, document_id: str):
         ids = self.get_ids_by_metadata_field('document_id', document_id)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         result = None
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
             )
             result = session.execute(select_statement).fetchall()
             if result:
@@ -114,56 +161,140 @@ class RelytVector(BaseVector):
         else:
             return None
 
+    def delete_by_uuids(self, ids: list[str] = None):
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+        """
+        from pgvecto_rs.sqlalchemy import Vector
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(self.embedding_dimension)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        try:
+            with self.client.connect() as conn:
+                with conn.begin():
+                    delete_condition = chunks_table.c.id.in_(ids)
+                    conn.execute(chunks_table.delete().where(delete_condition))
+                    return True
+        except Exception as e:
+            print("Delete operation failed:", str(e))  # noqa: T201
+            return False
+
     def delete_by_metadata_field(self, key: str, value: str):
         ids = self.get_ids_by_metadata_field(key, value)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)
 
     def delete_by_ids(self, doc_ids: list[str]) -> None:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
             )
             result = session.execute(select_statement).fetchall()
             if result:
                 ids = [item[0] for item in result]
-                self._client.delete_by_ids(ids)
+                self.delete_by_uuids(ids)
 
     def delete(self) -> None:
-        with Session(self._client._engine) as session:
-            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
+        with Session(self.client) as session:
+            session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
             session.commit()
 
     def text_exists(self, id: str) -> bool:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
             )
             result = session.execute(select_statement).fetchall()
             return len(result) > 0
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from pgvecto_rs.sdk import filters
-        filter_condition = filters.meta_contains(kwargs.get('filter'))
-        results = self._client.search(
-            top_k=int(kwargs.get('top_k')),
+        results = self.similarity_search_with_score_by_vector(
+            k=int(kwargs.get('top_k')),
             embedding=query_vector,
-            filter=filter_condition
+            filter=kwargs.get('filter')
         )
 
         # Organize results.
         docs = []
-        for record, dis in results:
-            metadata = record.meta
-            metadata['score'] = dis
+        for document, score in results:
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
-            if dis > score_threshold:
-                doc = Document(page_content=record.text,
-                               metadata=metadata)
-                docs.append(doc)
+            if score > score_threshold:
+                docs.append(document)
         return docs
 
+    def similarity_search_with_score_by_vector(
+            self,
+            embedding: list[float],
+            k: int = 4,
+            filter: Optional[dict] = None,
+    ) -> list[tuple[Document, float]]:
+        # Add the filter if provided
+        try:
+            from sqlalchemy.engine import Row
+        except ImportError:
+            raise ImportError(
+                "Could not import Row from sqlalchemy.engine. "
+                "Please 'pip install sqlalchemy>=1.4'."
+            )
+
+        filter_condition = ""
+        if filter is not None:
+            conditions = [
+                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
+                else f"metadata->>{key!r} = {value[0]!r}"
+                for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+        # Define the base query
+        sql_query = f"""
+            set vectors.enable_search_growing = on;
+            set vectors.enable_search_write = on;
+            SELECT document, metadata, embedding <-> :embedding as distance
+            FROM "{self._collection_name}"
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """
+
+        # Set up the query parameters
+        embedding_str = ", ".join(format(x) for x in embedding)
+        embedding_str = "[" + embedding_str + "]"
+        params = {"embedding": embedding_str, "k": k}
+
+        # Execute the query and fetch the results
+        with self.client.connect() as conn:
+            results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=result.document,
+                    metadata=result.metadata,
+                ),
+                result.distance,
+            )
+            for result in results
+        ]
+        return documents_with_scores
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []
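
The rewritten add_texts batches rows into explicit multi-row INSERTs of up to 500 records per statement, keeping round trips down without building one enormous statement. A self-contained sketch of that batching pattern, using a toy table and SQLite so it runs anywhere (table and column names here are illustrative, not the commit's schema):

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert

engine = create_engine("sqlite:///:memory:")  # stand-in for the Relyt/Postgres engine
metadata = MetaData()
chunks = Table(
    "chunks", metadata,
    Column("id", Integer, primary_key=True),
    Column("document", String),
)
metadata.create_all(engine)

rows = [{"id": i, "document": f"chunk {i}"} for i in range(1200)]
BATCH_SIZE = 500  # same batch size as the diff

with engine.connect() as conn:
    with conn.begin():
        batch = []
        for row in rows:
            batch.append(row)
            if len(batch) == BATCH_SIZE:
                # Flush a full batch as one multi-row INSERT
                conn.execute(insert(chunks).values(batch))
                batch.clear()
        if batch:
            # Flush whatever is left over
            conn.execute(insert(chunks).values(batch))

On the query side, search_by_vector now delegates to similarity_search_with_score_by_vector, which orders by the L2-distance operator `<->` and returns (Document, distance) pairs. A hedged usage sketch; the kwargs are the ones the method reads, and the filter shape assumes list values since the condition builder calls len() and repr() on each value:

# Hypothetical call site: `vector` is a RelytVector and `query_embedding`
# a list[float] matching the collection's dimension.
docs = vector.search_by_vector(
    query_embedding,
    top_k=4,
    score_threshold=0.0,
    filter={'group_id': [dataset_id]},  # rendered as metadata->>'group_id' = '...'
)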

Changed file 3 of 4 (vector store factory)

@@ -126,7 +126,6 @@ class Vector:
                     "vector_store": {"class_prefix": collection_name}
                 }
                 self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("hello relyt"))
             return RelytVector(
                 collection_name=collection_name,
                 config=RelytConfig(
@@ -136,7 +135,7 @@ class Vector:
                     password=config.get('RELYT_PASSWORD'),
                     database=config.get('RELYT_DATABASE'),
                 ),
-                dim=dim
+                group_id=self._dataset.id
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
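
The factory no longer probes the embedding model for its dimension (the deleted dim = len(embed_query("hello relyt")) cost one embedding call per construction); instead it passes the dataset id as group_id, and RelytVector picks the dimension up from the first real batch, defaulting to 1536 until then. A short sketch of that deferred-dimension flow, as reconstructed from the hunks above:

# Inside RelytVector (per the diff): the dimension defaults to 1536 in
# __init__ and is overwritten with the true size of the first batch here.
def create(self, texts, embeddings, **kwargs):
    self.create_collection(len(embeddings[0]))     # table created at data dimension
    self.embedding_dimension = len(embeddings[0])  # remembered for later Table schemas
    self.add_texts(texts, embeddings)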

Changed file 4 of 4 (Docker Compose service environment)

@@ -86,7 +86,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -109,6 +109,12 @@ services:
       MILVUS_PASSWORD: Milvus
       # The milvus tls switch.
       MILVUS_SECURE: 'false'
+      # relyt configurations
+      RELYT_HOST: db
+      RELYT_PORT: 5432
+      RELYT_USER: postgres
+      RELYT_PASSWORD: difyai123456
+      RELYT_DATABASE: postgres
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -193,7 +199,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
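
To run a compose deployment against Relyt, VECTOR_STORE is switched to relyt and the five RELYT_* variables above must point at a reachable instance; RelytVector assembles them into a standard psycopg2 DSN. A minimal sketch of that assembly (reading from os.environ here is illustrative; the app actually reads these values from its Flask config):

import os

# Mirrors the _url built in RelytVector.__init__ from the RELYT_* settings.
url = (
    f"postgresql+psycopg2://{os.environ['RELYT_USER']}:{os.environ['RELYT_PASSWORD']}"
    f"@{os.environ['RELYT_HOST']}:{os.environ['RELYT_PORT']}/{os.environ['RELYT_DATABASE']}"
)
print(url)  # e.g. postgresql+psycopg2://postgres:difyai123456@db:5432/postgres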