feat: AnalyticDB vector store supports invocation via SQL. (#10802)
Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
parent de6d3e493c · commit 873e9720e9
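
This change splits the AnalyticDB vector-store integration into two interchangeable backends: the existing Alibaba Cloud OpenAPI client (moved into analyticdb_vector_openapi.py) and a new direct SQL client built on psycopg2 (analyticdb_vector_sql.py). Which backend is used is driven purely by configuration; a minimal sketch of the selection rule, mirroring the factory branch later in this diff (assumes the dify_config fields added in this commit):

# Sketch only: the factory picks the OpenAPI backend when no host is configured,
# and the direct SQL backend otherwise.
from configs import dify_config

use_sql_backend = dify_config.ANALYTICDB_HOST is not None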
@@ -234,6 +234,10 @@ ANALYTICDB_ACCOUNT=testaccount
 ANALYTICDB_PASSWORD=testpassword
 ANALYTICDB_NAMESPACE=dify
 ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5
 
 # OpenSearch configuration
 OPENSEARCH_HOST=127.0.0.1

@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PositiveInt
 
 
 class AnalyticdbConfig(BaseModel):
@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
         description="The password for accessing the specified namespace within the AnalyticDB instance"
         " (if namespace feature is enabled).",
     )
+    ANALYTICDB_HOST: Optional[str] = Field(
+        default=None, description="The host of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_PORT: PositiveInt = Field(
+        default=5432, description="The port of the AnalyticDB instance you want to connect to."
+    )
+    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
+    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
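
The new fields use Pydantic's PositiveInt, so a zero or negative port or pool size is rejected at configuration time. A minimal illustration (hypothetical model, not part of the commit):

from pydantic import BaseModel, Field, PositiveInt, ValidationError


class _Example(BaseModel):  # hypothetical, for illustration only
    ANALYTICDB_PORT: PositiveInt = Field(default=5432)


_Example(ANALYTICDB_PORT=5432)  # accepted
try:
    _Example(ANALYTICDB_PORT=0)  # rejected: PositiveInt requires a value > 0
except ValidationError as exc:
    print(exc)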

analyticdb_vector.py is rewritten (310 lines down to 62): the inlined Alibaba Cloud OpenAPI client, its AnalyticdbConfig model, and the factory's per-field None checks are removed (the OpenAPI implementation now lives in analyticdb_vector_openapi.py), and AnalyticdbVector becomes a thin wrapper that delegates every operation to either the OpenAPI backend or the new SQL backend. The new version of the changed hunks:

@@ -1,310 +1,62 @@
import json
from typing import Any

from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
    AnalyticdbVectorOpenAPI,
    AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset


class AnalyticdbVector(BaseVector):
    def __init__(
        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
    ):
        super().__init__(collection_name)
        if api_config is not None:
            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
        else:
            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

    def get_type(self) -> str:
        return VectorType.ANALYTICDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self.analyticdb_vector._create_collection_if_not_exists(dimension)
        self.analyticdb_vector.add_texts(texts, embeddings)

    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(texts, embeddings)

    def text_exists(self, id: str) -> bool:
        return self.analyticdb_vector.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        self.analyticdb_vector.delete_by_ids(ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self.analyticdb_vector.delete_by_metadata_field(key, value)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        return self.analyticdb_vector.search_by_vector(query_vector)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return self.analyticdb_vector.search_by_full_text(query, **kwargs)

    def delete(self) -> None:
        self.analyticdb_vector.delete()


class AnalyticdbVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()

@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))

        if dify_config.ANALYTICDB_HOST is None:
            # implemented through OpenAPI
            apiConfig = AnalyticdbVectorOpenAPIConfig(
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,

@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            )
            sqlConfig = None
        else:
            # implemented through sql
            sqlConfig = AnalyticdbVectorBySqlConfig(
                host=dify_config.ANALYTICDB_HOST,
                port=dify_config.ANALYTICDB_PORT,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
            )
            apiConfig = None
        return AnalyticdbVector(
            collection_name,
            apiConfig,
            sqlConfig,
        )
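
A hedged usage sketch of the rewritten wrapper (connection values are placeholders; it mirrors the integration test further down). Exactly one of the two configs is supplied and the other is passed as None:

from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig

vector = AnalyticdbVector(
    collection_name="example_collection",  # placeholder
    api_config=None,  # skip the OpenAPI backend
    sql_config=AnalyticdbVectorBySqlConfig(
        host="gp-example.aliyuncs.com",  # placeholder instance host
        port=5432,
        account="testaccount",
        account_password="testpassword",
        min_connection=1,
        max_connection=5,
        namespace="dify",
    ),
)
# Every call below is forwarded to the chosen backend (AnalyticdbVectorBySql here):
# vector.create(texts, embeddings)
# vector.search_by_vector(query_vector)
# vector.delete()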

analyticdb_vector_openapi.py (new file, 309 lines) carries the Alibaba Cloud OpenAPI-based implementation that previously lived inline in analyticdb_vector.py:

@@ -0,0 +1,309 @@
import json
from typing import Any

from pydantic import BaseModel, model_validator

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str = (None,)
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e

api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py (new file, 245 lines):

@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

The AnalyticDB integration test is extended to exercise both backends:

@@ -1,27 +1,43 @@
 from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
+from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
 from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
 
 
 class AnalyticdbVectorTest(AbstractVectorTest):
-    def __init__(self):
+    def __init__(self, config_type: str):
         super().__init__()
         # Analyticdb requires collection_name length less than 60.
         # it's ok for normal usage.
         self.collection_name = self.collection_name.replace("_test", "")
-        self.vector = AnalyticdbVector(
-            collection_name=self.collection_name,
-            config=AnalyticdbConfig(
-                access_key_id="test_key_id",
-                access_key_secret="test_key_secret",
-                region_id="test_region",
-                instance_id="test_id",
-                account="test_account",
-                account_password="test_passwd",
-                namespace="difytest_namespace",
-                collection="difytest_collection",
-                namespace_password="test_passwd",
-            ),
-        )
+        if config_type == "sql":
+            self.vector = AnalyticdbVector(
+                collection_name=self.collection_name,
+                sql_config=AnalyticdbVectorBySqlConfig(
+                    host="test_host",
+                    port=5432,
+                    account="test_account",
+                    account_password="test_passwd",
+                    namespace="difytest_namespace",
+                ),
+                api_config=None,
+            )
+        else:
+            self.vector = AnalyticdbVector(
+                collection_name=self.collection_name,
+                sql_config=None,
+                api_config=AnalyticdbVectorOpenAPIConfig(
+                    access_key_id="test_key_id",
+                    access_key_secret="test_key_secret",
+                    region_id="test_region",
+                    instance_id="test_id",
+                    account="test_account",
+                    account_password="test_passwd",
+                    namespace="difytest_namespace",
+                    collection="difytest_collection",
+                    namespace_password="test_passwd",
+                ),
+            )
 
     def run_all_tests(self):
         self.vector.delete()
@@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest):
 
 
 def test_chroma_vector(setup_mock_redis):
-    AnalyticdbVectorTest().run_all_tests()
+    AnalyticdbVectorTest("api").run_all_tests()
+    AnalyticdbVectorTest("sql").run_all_tests()
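
Since both backends expose the same AnalyticdbVector surface, the two runs above could also be written with pytest parametrization; a hedged alternative sketch (the test name here is illustrative only):

import pytest


@pytest.mark.parametrize("config_type", ["api", "sql"])
def test_analyticdb_vector(setup_mock_redis, config_type):
    AnalyticdbVectorTest(config_type).run_all_tests()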

@@ -450,6 +450,10 @@ ANALYTICDB_ACCOUNT=testaccount
 ANALYTICDB_PASSWORD=testpassword
 ANALYTICDB_NAMESPACE=dify
 ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+ANALYTICDB_HOST=gp-test.aliyuncs.com
+ANALYTICDB_PORT=5432
+ANALYTICDB_MIN_CONNECTION=1
+ANALYTICDB_MAX_CONNECTION=5
 
 # TiDB vector configurations, only available when VECTOR_STORE is `tidb`
 TIDB_VECTOR_HOST=tidb
@@ -185,6 +185,10 @@ x-shared-env: &shared-api-worker-env
   ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
   ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
   ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
+  ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
+  ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
+  ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
+  ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
   OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
   OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
   OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}