fix: the relyt vector database cannot be used via the API (#3766)
Co-authored-by: jingsi <jingsi@leadincloud.com>
Parent: bf9fc8fef4
Commit: 1be222af2e
@@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
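
With this change, the retrieval-settings endpoint advertises the same capability set for relyt as for milvus: semantic search only, no full-text or hybrid methods. A sketch of the response body this branch returns, assuming the surrounding handler code is unchanged from the existing milvus path:

    {
        'retrieval_method': [
            'semantic_search'
        ]
    }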
@@ -498,7 +498,7 @@ class DatasetRetrievalSettingMockApi(Resource):
     @account_initialization_required
     def get(self, vector_type):

-        if vector_type == 'milvus':
+        if vector_type == 'milvus' or vector_type == 'relyt':
             return {
                 'retrieval_method': [
                     'semantic_search'
@@ -1,16 +1,23 @@
-import logging
-from typing import Any
+import uuid
+from typing import Any, Optional

-from pgvecto_rs.sdk import PGVectoRs, Record
 from pydantic import BaseModel, root_validator
+from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
+from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session

+try:
+    from sqlalchemy.orm import declarative_base
+except ImportError:
+    from sqlalchemy.ext.declarative import declarative_base
+
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client

-logger = logging.getLogger(__name__)
+Base = declarative_base()  # type: Any


 class RelytConfig(BaseModel):
     host: str
@@ -36,16 +43,14 @@ class RelytConfig(BaseModel):

 class RelytVector(BaseVector):

-    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
+    def __init__(self, collection_name: str, config: RelytConfig, group_id: str):
         super().__init__(collection_name)
+        self.embedding_dimension = 1536
         self._client_config = config
         self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
-        self._client = PGVectoRs(
-            db_url=self._url,
-            collection_name=self._collection_name,
-            dimension=dim
-        )
+        self.client = create_engine(self._url)
         self._fields = []
+        self._group_id = group_id

     def get_type(self) -> str:
         return 'relyt'
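
The constructor no longer takes the embedding dimension; it receives the owning dataset's id as group_id, defaults the dimension to 1536 until real embeddings arrive, and replaces the PGVectoRs SDK client with a plain SQLAlchemy engine. A minimal construction sketch, assuming a reachable PostgreSQL-compatible Relyt instance; the host, credentials, and ids below are placeholders:

    config = RelytConfig(
        host='localhost',          # placeholder endpoint
        port=5432,
        user='postgres',
        password='difyai123456',
        database='postgres',
    )
    vector = RelytVector(
        collection_name='vector_index_abc',  # placeholder collection name
        config=config,
        group_id='dataset-uuid',             # placeholder dataset id
    )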
@@ -54,6 +59,7 @@ class RelytVector(BaseVector):
         index_params = {}
         metadatas = [d.metadata for d in texts]
         self.create_collection(len(embeddings[0]))
+        self.embedding_dimension = len(embeddings[0])
         self.add_texts(texts, embeddings)

     def create_collection(self, dimension: int):
@@ -63,21 +69,21 @@ class RelytVector(BaseVector):
         if redis_client.get(collection_exist_cache_key):
             return
         index_name = f"{self._collection_name}_embedding_index"
-        with Session(self._client._engine) as session:
-            drop_statement = sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}")
+        with Session(self.client) as session:
+            drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """)
             session.execute(drop_statement)
             create_statement = sql_text(f"""
-                CREATE TABLE IF NOT EXISTS collection_{self._collection_name} (
-                    id UUID PRIMARY KEY,
-                    text TEXT NOT NULL,
-                    meta JSONB NOT NULL,
+                CREATE TABLE IF NOT EXISTS "{self._collection_name}" (
+                    id TEXT PRIMARY KEY,
+                    document TEXT NOT NULL,
+                    metadata JSON NOT NULL,
                     embedding vector({dimension}) NOT NULL
                 ) using heap;
             """)
             session.execute(create_statement)
             index_statement = sql_text(f"""
                 CREATE INDEX {index_name}
-                ON collection_{self._collection_name} USING vectors(embedding vector_l2_ops)
+                ON "{self._collection_name}" USING vectors(embedding vector_l2_ops)
                 WITH (options = $$
                 optimizing.optimizing_threads = 30
                 segment.max_growing_segment_size = 2000
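
The DDL now quotes the collection name as the table name directly (dropping the old collection_ prefix) and stores ids as TEXT with a JSON metadata column. For a hypothetical collection named vector_index_abc with 1536-dimensional embeddings, the create statement expands to roughly:

    CREATE TABLE IF NOT EXISTS "vector_index_abc" (
        id TEXT PRIMARY KEY,
        document TEXT NOT NULL,
        metadata JSON NOT NULL,
        embedding vector(1536) NOT NULL
    ) using heap;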
@@ -92,21 +98,62 @@ class RelytVector(BaseVector):
         redis_client.set(collection_exist_cache_key, 1, ex=3600)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        records = [Record.from_text(d.page_content, e, d.metadata) for d, e in zip(documents, embeddings)]
-        pks = [str(r.id) for r in records]
-        self._client.insert(records)
-        return pks
+        from pgvecto_rs.sqlalchemy import Vector
+
+        ids = [str(uuid.uuid1()) for _ in documents]
+        metadatas = [d.metadata for d in documents]
+        for metadata in metadatas:
+            metadata['group_id'] = self._group_id
+        texts = [d.page_content for d in documents]
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(len(embeddings[0]))),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        chunks_table_data = []
+        with self.client.connect() as conn:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                    texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == 500:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+
+        return ids

     def delete_by_document_id(self, document_id: str):
         ids = self.get_ids_by_metadata_field('document_id', document_id)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)

     def get_ids_by_metadata_field(self, key: str, value: str):
         result = None
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'{key}' = '{value}'; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
             )
             result = session.execute(select_statement).fetchall()
             if result:
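
add_texts now generates its own uuid1 ids, stamps every chunk's metadata with the dataset's group_id, and inserts rows through the engine in batches of 500 instead of going through the SDK. A hedged usage sketch, assuming the collection table already exists and the embedding dimension matches it; the document contents and vectors are illustrative:

    docs = [Document(page_content='hello relyt',
                     metadata={'doc_id': 'd1', 'document_id': 'doc-1'})]
    embeddings = [[0.1] * 1536]                # illustrative document vectors
    ids = vector.add_texts(docs, embeddings)   # returns the generated uuid strings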
@@ -114,56 +161,140 @@ class RelytVector(BaseVector):
             else:
                 return None

+    def delete_by_uuids(self, ids: list[str] = None):
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+        """
+        from pgvecto_rs.sqlalchemy import Vector
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        # Define the table schema
+        chunks_table = Table(
+            self._collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", Vector(self.embedding_dimension)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+
+        try:
+            with self.client.connect() as conn:
+                with conn.begin():
+                    delete_condition = chunks_table.c.id.in_(ids)
+                    conn.execute(chunks_table.delete().where(delete_condition))
+                    return True
+        except Exception as e:
+            print("Delete operation failed:", str(e))  # noqa: T201
+            return False
+
     def delete_by_metadata_field(self, key: str, value: str):

         ids = self.get_ids_by_metadata_field(key, value)
         if ids:
-            self._client.delete_by_ids(ids)
+            self.delete_by_uuids(ids)

     def delete_by_ids(self, doc_ids: list[str]) -> None:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' in ('{doc_ids}'); "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
             )
             result = session.execute(select_statement).fetchall()
             if result:
                 ids = [item[0] for item in result]
-                self._client.delete_by_ids(ids)
+                self.delete_by_uuids(ids)

     def delete(self) -> None:
-        with Session(self._client._engine) as session:
-            session.execute(sql_text(f"DROP TABLE IF EXISTS collection_{self._collection_name}"))
+        with Session(self.client) as session:
+            session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";"""))
             session.commit()

     def text_exists(self, id: str) -> bool:
-        with Session(self._client._engine) as session:
+        with Session(self.client) as session:
             select_statement = sql_text(
-                f"SELECT id FROM collection_{self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
+                f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
             )
             result = session.execute(select_statement).fetchall()
             return len(result) > 0

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from pgvecto_rs.sdk import filters
-        filter_condition = filters.meta_contains(kwargs.get('filter'))
-        results = self._client.search(
-            top_k=int(kwargs.get('top_k')),
+        results = self.similarity_search_with_score_by_vector(
+            k=int(kwargs.get('top_k')),
             embedding=query_vector,
-            filter=filter_condition
+            filter=kwargs.get('filter')
         )

         # Organize results.
         docs = []
-        for record, dis in results:
-            metadata = record.meta
-            metadata['score'] = dis
+        for document, score in results:
             score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0
-            if dis > score_threshold:
-                doc = Document(page_content=record.text,
-                               metadata=metadata)
-                docs.append(doc)
+            if score > score_threshold:
+                docs.append(document)
         return docs

+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: list[float],
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> list[tuple[Document, float]]:
+        # Add the filter if provided
+        try:
+            from sqlalchemy.engine import Row
+        except ImportError:
+            raise ImportError(
+                "Could not import Row from sqlalchemy.engine. "
+                "Please 'pip install sqlalchemy>=1.4'."
+            )
+
+        filter_condition = ""
+        if filter is not None:
+            conditions = [
+                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
+                else f"metadata->>{key!r} = {value[0]!r}"
+                for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+        # Define the base query
+        sql_query = f"""
+            set vectors.enable_search_growing = on;
+            set vectors.enable_search_write = on;
+            SELECT document, metadata, embedding <-> :embedding as distance
+            FROM "{self._collection_name}"
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """
+
+        # Set up the query parameters
+        embedding_str = ", ".join(format(x) for x in embedding)
+        embedding_str = "[" + embedding_str + "]"
+        params = {"embedding": embedding_str, "k": k}
+
+        # Execute the query and fetch the results
+        with self.client.connect() as conn:
+            results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=result.document,
+                    metadata=result.metadata,
+                ),
+                result.distance,
+            )
+            for result in results
+        ]
+        return documents_with_scores
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []
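search_by_vector now delegates to the new similarity_search_with_score_by_vector, which issues a raw `embedding <-> :embedding` distance query against the quoted table and renders the metadata filter into the WHERE clause. Note that the filter builder calls len(value) and value[0], so filter values are expected to be lists. A call sketch with placeholder arguments:

    docs = vector.search_by_vector(
        query_vector,                            # the query embedding
        top_k=4,
        score_threshold=0.0,
        filter={'group_id': ['dataset-uuid']},   # list-valued; placeholder id
    )
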
@@ -126,7 +126,6 @@ class Vector:
                 "vector_store": {"class_prefix": collection_name}
             }
             self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("hello relyt"))
             return RelytVector(
                 collection_name=collection_name,
                 config=RelytConfig(
@@ -136,7 +135,7 @@ class Vector:
                     password=config.get('RELYT_PASSWORD'),
                     database=config.get('RELYT_DATABASE'),
                 ),
-                dim=dim
+                group_id=self._dataset.id
             )
         else:
             raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

@@ -86,7 +86,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
@@ -109,6 +109,12 @@ services:
       MILVUS_PASSWORD: Milvus
       # The milvus tls switch.
       MILVUS_SECURE: 'false'
+      # relyt configurations
+      RELYT_HOST: db
+      RELYT_PORT: 5432
+      RELYT_USER: postgres
+      RELYT_PASSWORD: difyai123456
+      RELYT_DATABASE: postgres
       # Mail configuration, support: resend, smtp
       MAIL_TYPE: ''
       # default send from email address, if not specified
@@ -193,7 +199,7 @@ services:
       AZURE_BLOB_ACCOUNT_KEY: 'difyai'
       AZURE_BLOB_CONTAINER_NAME: 'difyai-container'
       AZURE_BLOB_ACCOUNT_URL: 'https://<your_account_name>.blob.core.windows.net'
-      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`.
+      # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`.
       VECTOR_STORE: weaviate
       # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
       WEAVIATE_ENDPOINT: http://weaviate:8080
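
To actually select Relyt, VECTOR_STORE must also be switched from weaviate to relyt; the RELYT_* values above are the defaults this commit adds and assume a PostgreSQL-compatible Relyt endpoint reachable as db. A sketch of the resulting environment block:

      VECTOR_STORE: relyt
      # relyt configurations (defaults from this commit; adjust for your deployment)
      RELYT_HOST: db
      RELYT_PORT: 5432
      RELYT_USER: postgres
      RELYT_PASSWORD: difyai123456
      RELYT_DATABASE: postgres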