mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-05-13 20:28:16 +08:00
feat: support vastbase vector database (#16308)
This commit is contained in:
parent
cd9e6609ad
commit
0babdffe3e
@ -271,6 +271,7 @@ def migrate_knowledge_vector_database():
|
||||
upper_collection_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.RELYT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.ORACLE,
|
||||
|
@ -39,6 +39,7 @@ from .vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
|
||||
from .vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from .vdb.upstash_config import UpstashConfig
|
||||
from .vdb.vastbase_vector_config import VastbaseVectorConfig
|
||||
from .vdb.vikingdb_config import VikingDBConfig
|
||||
from .vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
@ -270,6 +271,7 @@ class MiddlewareConfig(
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
PGVectorConfig,
|
||||
VastbaseVectorConfig,
|
||||
PGVectoRSConfig,
|
||||
QdrantConfig,
|
||||
RelytConfig,
|
||||
|
45
api/configs/middleware/vdb/vastbase_vector_config.py
Normal file
45
api/configs/middleware/vdb/vastbase_vector_config.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class VastbaseVectorConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Vector (Vastbase with vector extension)
|
||||
"""
|
||||
|
||||
VASTBASE_HOST: Optional[str] = Field(
|
||||
description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_PORT: PositiveInt = Field(
|
||||
description="Port number on which the Vastbase server is listening (default is 5432)",
|
||||
default=5432,
|
||||
)
|
||||
|
||||
VASTBASE_USER: Optional[str] = Field(
|
||||
description="Username for authenticating with the Vastbase database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_PASSWORD: Optional[str] = Field(
|
||||
description="Password for authenticating with the Vastbase database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_DATABASE: Optional[str] = Field(
|
||||
description="Name of the Vastbase database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VASTBASE_MIN_CONNECTION: PositiveInt = Field(
|
||||
description="Min connection of the Vastbase database",
|
||||
default=1,
|
||||
)
|
||||
|
||||
VASTBASE_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Max connection of the Vastbase database",
|
||||
default=5,
|
||||
)
|
@ -657,6 +657,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
@ -706,6 +707,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
|
0
api/core/rag/datasource/vdb/pyvastbase/__init__.py
Normal file
0
api/core/rag/datasource/vdb/pyvastbase/__init__.py
Normal file
243
api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
Normal file
243
api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
Normal file
@ -0,0 +1,243 @@
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras # type: ignore
|
||||
import psycopg2.pool # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
class VastbaseVectorConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config VASTBASE_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config VASTBASE_PORT is required")
|
||||
if not values["user"]:
|
||||
raise ValueError("config VASTBASE_USER is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config VASTBASE_PASSWORD is required")
|
||||
if not values["database"]:
|
||||
raise ValueError("config VASTBASE_DATABASE is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config VASTBASE_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config VASTBASE_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
SQL_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
id UUID PRIMARY KEY,
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
embedding floatvector({dimension}) NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
SQL_CREATE_INDEX = """
|
||||
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
|
||||
USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
|
||||
"""
|
||||
|
||||
|
||||
class VastbaseVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: VastbaseVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = f"embedding_{collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.VASTBASE
|
||||
|
||||
def _create_connection_pool(self, config: VastbaseVectorConfig):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
config.min_connection,
|
||||
config.max_connection,
|
||||
host=config.host,
|
||||
port=config.port,
|
||||
user=config.user,
|
||||
password=config.password,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
self.pool.putconn(conn)
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
embeddings[i],
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_values(
|
||||
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
|
||||
)
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
docs = []
|
||||
for record in cur:
|
||||
docs.append(Document(page_content=record[1], metadata=record[0]))
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
|
||||
# Scenario 1: extract a document fails, resulting in a table not being created.
|
||||
# Then clicking the retry button triggers a delete operation on an empty list.
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:param top_k: The number of nearest neighbors to return, default is 5.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
|
||||
f" ORDER BY distance LIMIT {top_k}",
|
||||
(json.dumps(query_vector),),
|
||||
)
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
for record in cur:
|
||||
metadata, text, distance = record
|
||||
score = 1 - distance
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector(text) @@ plainto_tsquery(%s)
|
||||
ORDER BY score DESC
|
||||
LIMIT {top_k}""",
|
||||
# f"'{query}'" is required in order to account for whitespace in query
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
|
||||
docs = []
|
||||
|
||||
for record in cur:
|
||||
metadata, text, score = record
|
||||
metadata["score"] = score
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
def _create_collection(self, dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
|
||||
# Vastbase 支持的向量维度取值范围为 [1,16000]
|
||||
if dimension <= 16000:
|
||||
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
|
||||
class VastbaseVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VASTBASE, collection_name))
|
||||
|
||||
return VastbaseVector(
|
||||
collection_name=collection_name,
|
||||
config=VastbaseVectorConfig(
|
||||
host=dify_config.VASTBASE_HOST or "localhost",
|
||||
port=dify_config.VASTBASE_PORT,
|
||||
user=dify_config.VASTBASE_USER or "dify",
|
||||
password=dify_config.VASTBASE_PASSWORD or "",
|
||||
database=dify_config.VASTBASE_DATABASE or "dify",
|
||||
min_connection=dify_config.VASTBASE_MIN_CONNECTION,
|
||||
max_connection=dify_config.VASTBASE_MAX_CONNECTION,
|
||||
),
|
||||
)
|
@ -74,6 +74,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
|
||||
|
||||
return PGVectorFactory
|
||||
case VectorType.VASTBASE:
|
||||
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
|
||||
|
||||
return VastbaseVectorFactory
|
||||
case VectorType.PGVECTO_RS:
|
||||
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
|
||||
|
||||
|
@ -7,7 +7,9 @@ class VectorType(StrEnum):
|
||||
MILVUS = "milvus"
|
||||
MYSCALE = "myscale"
|
||||
PGVECTOR = "pgvector"
|
||||
VASTBASE = "vastbase"
|
||||
PGVECTO_RS = "pgvecto-rs"
|
||||
|
||||
QDRANT = "qdrant"
|
||||
RELYT = "relyt"
|
||||
TIDB_VECTOR = "tidb_vector"
|
||||
|
@ -0,0 +1,27 @@
|
||||
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
class VastbaseVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = VastbaseVector(
|
||||
collection_name=self.collection_name,
|
||||
config=VastbaseVectorConfig(
|
||||
host="localhost",
|
||||
port=5434,
|
||||
user="dify",
|
||||
password="Difyai123456",
|
||||
database="dify",
|
||||
min_connection=1,
|
||||
max_connection=5,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_vastbase_vector(setup_mock_redis):
|
||||
VastbaseVectorTest().run_all_tests()
|
@ -441,6 +441,15 @@ PGVECTOR_MAX_CONNECTION=5
|
||||
PGVECTOR_PG_BIGM=false
|
||||
PGVECTOR_PG_BIGM_VERSION=1.2-20240606
|
||||
|
||||
# vastbase configurations, only available when VECTOR_STORE is `vastbase`
|
||||
VASTBASE_HOST=vastbase
|
||||
VASTBASE_PORT=5432
|
||||
VASTBASE_USER=dify
|
||||
VASTBASE_PASSWORD=Difyai123456
|
||||
VASTBASE_DATABASE=dify
|
||||
VASTBASE_MIN_CONNECTION=1
|
||||
VASTBASE_MAX_CONNECTION=5
|
||||
|
||||
# pgvecto-rs configurations, only available when VECTOR_STORE is `pgvecto-rs`
|
||||
PGVECTO_RS_HOST=pgvecto-rs
|
||||
PGVECTO_RS_PORT=5432
|
||||
|
@ -363,6 +363,30 @@ services:
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
|
||||
# get image from https://www.vastdata.com.cn/
|
||||
vastbase:
|
||||
image: vastdata/vastbase-vector
|
||||
profiles:
|
||||
- vastbase
|
||||
restart: always
|
||||
environment:
|
||||
- VB_DBCOMPATIBILITY=PG
|
||||
- VB_DB=dify
|
||||
- VB_USERNAME=dify
|
||||
- VB_PASSWORD=Difyai123456
|
||||
ports:
|
||||
- '5434:5432'
|
||||
volumes:
|
||||
- ./vastbase/lic:/home/vastbase/vastbase/lic
|
||||
- ./vastbase/data:/home/vastbase/data
|
||||
- ./vastbase/backup:/home/vastbase/backup
|
||||
- ./vastbase/backup_log:/home/vastbase/backup_log
|
||||
healthcheck:
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
|
||||
# pgvecto-rs vector store
|
||||
pgvecto-rs:
|
||||
image: tensorchord/pgvecto-rs:pg16-v0.3.0
|
||||
|
@ -163,6 +163,13 @@ x-shared-env: &shared-api-worker-env
|
||||
PGVECTOR_MAX_CONNECTION: ${PGVECTOR_MAX_CONNECTION:-5}
|
||||
PGVECTOR_PG_BIGM: ${PGVECTOR_PG_BIGM:-false}
|
||||
PGVECTOR_PG_BIGM_VERSION: ${PGVECTOR_PG_BIGM_VERSION:-1.2-20240606}
|
||||
VASTBASE_HOST: ${VASTBASE_HOST:-vastbase}
|
||||
VASTBASE_PORT: ${VASTBASE_PORT:-5432}
|
||||
VASTBASE_USER: ${VASTBASE_USER:-dify}
|
||||
VASTBASE_PASSWORD: ${VASTBASE_PASSWORD:-Difyai123456}
|
||||
VASTBASE_DATABASE: ${VASTBASE_DATABASE:-dify}
|
||||
VASTBASE_MIN_CONNECTION: ${VASTBASE_MIN_CONNECTION:-1}
|
||||
VASTBASE_MAX_CONNECTION: ${VASTBASE_MAX_CONNECTION:-5}
|
||||
PGVECTO_RS_HOST: ${PGVECTO_RS_HOST:-pgvecto-rs}
|
||||
PGVECTO_RS_PORT: ${PGVECTO_RS_PORT:-5432}
|
||||
PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres}
|
||||
@ -840,6 +847,30 @@ services:
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
|
||||
# get image from https://www.vastdata.com.cn/
|
||||
vastbase:
|
||||
image: vastdata/vastbase-vector
|
||||
profiles:
|
||||
- vastbase
|
||||
restart: always
|
||||
environment:
|
||||
- VB_DBCOMPATIBILITY=PG
|
||||
- VB_DB=dify
|
||||
- VB_USERNAME=dify
|
||||
- VB_PASSWORD=Difyai123456
|
||||
ports:
|
||||
- '5434:5432'
|
||||
volumes:
|
||||
- ./vastbase/lic:/home/vastbase/vastbase/lic
|
||||
- ./vastbase/data:/home/vastbase/data
|
||||
- ./vastbase/backup:/home/vastbase/backup
|
||||
- ./vastbase/backup_log:/home/vastbase/backup_log
|
||||
healthcheck:
|
||||
test: [ 'CMD', 'pg_isready' ]
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 30
|
||||
|
||||
# pgvecto-rs vector store
|
||||
pgvecto-rs:
|
||||
image: tensorchord/pgvecto-rs:pg16-v0.3.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user