feat: AnalyticDB vector store supports invocation via SQL. (#10802)

Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
8bitpd 2024-11-18 19:29:54 +08:00 committed by GitHub
parent de6d3e493c
commit 873e9720e9
8 changed files with 653 additions and 310 deletions

View File

@@ -234,6 +234,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
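The four ANALYTICDB_* additions above (HOST, PORT, MIN_CONNECTION, MAX_CONNECTION) only take effect when the vector store is reached over SQL. As a rough orientation sketch (an editor's illustration, not part of this diff; os.environ stands in for dify_config), the choice between the two access paths comes down to whether a host is configured:

import os

# Illustration only: mirrors the factory logic added later in this commit.
if os.environ.get("ANALYTICDB_HOST"):
    mode = "sql"       # direct psycopg2 connection; PORT and the pool bounds are used
else:
    mode = "openapi"   # existing Alibaba Cloud OpenAPI client path
print(f"AnalyticDB access mode: {mode}")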

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseModel):
@@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel):
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
)
ANALYTICDB_HOST: Optional[str] = Field(
default=None, description="The host of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_PORT: PositiveInt = Field(
default=5432, description="The port of the AnalyticDB instance you want to connect to."
)
ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Minimum connections in the AnalyticDB connection pool.")
ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Maximum connections in the AnalyticDB connection pool.")
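For context, PositiveInt makes the new port and pool-size settings fail fast on zero or negative values. A standalone sketch (hypothetical model, not the actual Dify config class):

from pydantic import BaseModel, Field, PositiveInt, ValidationError

class PoolSettings(BaseModel):
    # hypothetical stand-in for the ANALYTICDB_* fields above
    port: PositiveInt = Field(default=5432)
    min_connection: PositiveInt = Field(default=1)
    max_connection: PositiveInt = Field(default=5)

PoolSettings()                      # defaults pass validation
try:
    PoolSettings(min_connection=0)  # rejected: PositiveInt requires values > 0
except ValidationError as exc:
    print(exc)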

View File

@@ -1,310 +1,62 @@
import json
from typing import Any
from pydantic import BaseModel
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
class AnalyticdbConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = ("dify",)
namespace_password: str = (None,)
metrics: str = ("cosine",)
read_timeout: int = 60000
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
def __init__(self, collection_name: str, config: AnalyticdbConfig):
self._collection_name = collection_name.lower()
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except:
raise ImportError(_import_err_msg)
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_indexing_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}"
if redis_client.get(collection_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
self.add_texts(texts, embeddings)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
return self.analyticdb_vector.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
self.analyticdb_vector.delete_by_metadata_field(key, value)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_vector(query_vector)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.metadata.get("vector"),
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
self.analyticdb_vector.delete()
class AnalyticdbVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
@@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
# handle optional params
if dify_config.ANALYTICDB_KEY_ID is None:
raise ValueError("ANALYTICDB_KEY_ID should not be None")
if dify_config.ANALYTICDB_KEY_SECRET is None:
raise ValueError("ANALYTICDB_KEY_SECRET should not be None")
if dify_config.ANALYTICDB_REGION_ID is None:
raise ValueError("ANALYTICDB_REGION_ID should not be None")
if dify_config.ANALYTICDB_INSTANCE_ID is None:
raise ValueError("ANALYTICDB_INSTANCE_ID should not be None")
if dify_config.ANALYTICDB_ACCOUNT is None:
raise ValueError("ANALYTICDB_ACCOUNT should not be None")
if dify_config.ANALYTICDB_PASSWORD is None:
raise ValueError("ANALYTICDB_PASSWORD should not be None")
if dify_config.ANALYTICDB_NAMESPACE is None:
raise ValueError("ANALYTICDB_NAMESPACE should not be None")
if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None:
raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None")
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
@@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
),
)
sqlConfig = None
else:
# implemented through sql
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
)
apiConfig = None
return AnalyticdbVector(
collection_name,
apiConfig,
sqlConfig,
)
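Once one of the two configs is supplied, the facade above can also be constructed directly. A minimal usage sketch for the new SQL path (placeholder credentials; it assumes a reachable AnalyticDB instance and the Redis client used for the initialization locks):

from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig

# SQL mode: api_config is None, so every call is delegated to AnalyticdbVectorBySql.
vector = AnalyticdbVector(
    collection_name="example_collection",
    api_config=None,
    sql_config=AnalyticdbVectorBySqlConfig(
        host="gp-test.aliyuncs.com",   # placeholder values
        port=5432,
        account="testaccount",
        account_password="testpassword",
        min_connection=1,
        max_connection=5,
        namespace="dify",
    ),
)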

View File

@@ -0,0 +1,309 @@
import json
from typing import Any, Optional
from pydantic import BaseModel, model_validator
_import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
region_id: str
instance_id: str
account: str
account_password: str
namespace: str = "dify"
namespace_password: Optional[str] = None
metrics: str = "cosine"
read_timeout: int = 60000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["access_key_id"]:
raise ValueError("config ANALYTICDB_KEY_ID is required")
if not values["access_key_secret"]:
raise ValueError("config ANALYTICDB_KEY_SECRET is required")
if not values["region_id"]:
raise ValueError("config ANALYTICDB_REGION_ID is required")
if not values["instance_id"]:
raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["namespace_password"]:
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
except ImportError:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
self.config = config
self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
self._client = Client(self._client_config)
self._initialize()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.instance_id}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
self._create_namespace_if_not_exists()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.init_vector_database(request)
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
)
self._client.describe_namespace(request)
except TeaException as e:
if e.statusCode == 404:
request = gpdb_20160503_models.CreateNamespaceRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
)
self._client.create_namespace(request)
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
try:
request = gpdb_20160503_models.DescribeCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
)
self._client.describe_collection(request)
except TeaException as e:
if e.statusCode == 404:
metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
full_text_retrieval_fields = "page_content"
request = gpdb_20160503_models.CreateCollectionRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
manager_account=self.config.account,
manager_account_password=self.config.account_password,
namespace=self.config.namespace,
collection=self._collection_name,
dimension=embedding_dimension,
metrics=self.config.metrics,
metadata=metadata,
full_text_retrieval_fields=full_text_retrieval_fields,
)
self._client.create_collection(request)
else:
raise ValueError(f"failed to create collection {self._collection_name}: {e}")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
rows=rows,
)
self._client.upsert_collection_data(request)
def text_exists(self, id: str) -> bool:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
ids_str = f"({ids_str})"
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
def delete_by_metadata_field(self, key: str, value: str) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=None,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value,
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents
def delete(self) -> None:
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
collection=self._collection_name,
dbinstance_id=self.config.instance_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
region_id=self.config.region_id,
)
self._client.delete_collection(request)
except Exception as e:
raise e
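A note on the row layout used by the OpenAPI implementation above: add_texts stores each chunk's doc_id in the filterable ref_doc_id column and keeps the full metadata dict serialized under metadata_, which the search methods later parse back and enrich with the match score. A schematic sketch of that round trip (plain dicts instead of the SDK request objects, hypothetical values):

import json

chunk_metadata = {"doc_id": "abc-123", "document_id": "doc-1"}
page_content = "hello analyticdb"

# What add_texts puts into each UpsertCollectionDataRequestRows entry.
row_metadata = {
    "ref_doc_id": chunk_metadata["doc_id"],   # used by the text_exists/delete_by_ids filters
    "page_content": page_content,
    "metadata_": json.dumps(chunk_metadata),  # full metadata kept as JSON
}

# What search_by_vector/search_by_full_text reconstruct from a match.
restored = json.loads(row_metadata["metadata_"])
restored["score"] = 0.87  # match.score is attached before the Document is built
assert restored["doc_id"] == "abc-123"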

View File

@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbVectorBySqlConfig(BaseModel):
host: str
port: int
account: str
account_password: str
min_connection: int
max_connection: int
namespace: str = "dify"
metrics: str = "cosine"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ANALYTICDB_HOST is required")
if not values["port"]:
raise ValueError("config ANALYTICDB_PORT is required")
if not values["account"]:
raise ValueError("config ANALYTICDB_ACCOUNT is required")
if not values["account_password"]:
raise ValueError("config ANALYTICDB_PASSWORD is required")
if not values["min_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
if not values["max_connection"]:
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
if values["min_connection"] > values["max_connection"]:
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
return values
class AnalyticdbVectorBySql:
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
self._collection_name = collection_name.lower()
self.databaseName = "knowledgebase"
self.config = config
self.table_name = f"{self.config.namespace}.{self._collection_name}"
self.pool = None
self._initialize()
if not self.pool:
self.pool = self._create_connection_pool()
def _initialize(self) -> None:
cache_key = f"vector_initialize_{self.config.host}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
database_exist_cache_key = f"vector_initialize_{self.config.host}"
if redis_client.get(database_exist_cache_key):
return
self._initialize_vector_database()
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _create_connection_pool(self):
return psycopg2.pool.SimpleConnectionPool(
self.config.min_connection,
self.config.max_connection,
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database=self.databaseName,
)
@contextmanager
def _get_cursor(self):
conn = self.pool.getconn()
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
self.pool.putconn(conn)
def _initialize_vector_database(self) -> None:
conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
user=self.config.account,
password=self.config.account_password,
database="postgres",
)
conn.autocommit = True
cur = conn.cursor()
try:
cur.execute(f"CREATE DATABASE {self.databaseName}")
except Exception as e:
if "already exists" in str(e):
return
raise e
finally:
cur.close()
conn.close()
self.pool = self._create_connection_pool()
with self._get_cursor() as cur:
try:
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
except Exception as e:
if "already exists" not in str(e):
raise e
cur.execute(
"CREATE OR REPLACE FUNCTION "
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
"AS words_only;$function$"
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
cur.execute(
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
f"id text PRIMARY KEY,"
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
f"to_tsvector TSVECTOR"
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
)
if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx"
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
f"pq_enable=0, external_storage=0)"
)
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
id_prefix = str(uuid.uuid4()) + "_"
sql = f"""
INSERT INTO {self.table_name}
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
except Exception as e:
if "does not exist" not in str(e):
raise e
def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
try:
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
except Exception as e:
if "does not exist" not in str(e):
raise e
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
with self._get_cursor() as cur:
query_vector_str = json.dumps(query_vector)
query_vector_str = "{" + query_vector_str[1:-1] + "}"
cur.execute(
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
f"t.page_content as page_content, t.metadata_ AS metadata_ "
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
(query_vector_str,),
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
with self._get_cursor() as cur:
cur.execute(
f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
ORDER BY score DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -1,16 +1,32 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
class AnalyticdbVectorTest(AbstractVectorTest):
def __init__(self):
def __init__(self, config_type: str):
super().__init__()
# AnalyticDB requires collection_name length to be less than 60,
# which is fine for normal usage.
self.collection_name = self.collection_name.replace("_test", "")
if config_type == "sql":
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
config=AnalyticdbConfig(
sql_config=AnalyticdbVectorBySqlConfig(
host="test_host",
port=5432,
account="test_account",
account_password="test_passwd",
namespace="difytest_namespace",
),
api_config=None,
)
else:
self.vector = AnalyticdbVector(
collection_name=self.collection_name,
sql_config=None,
api_config=AnalyticdbVectorOpenAPIConfig(
access_key_id="test_key_id",
access_key_secret="test_key_secret",
region_id="test_region",
@@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest):
def test_analyticdb_vector(setup_mock_redis):
AnalyticdbVectorTest().run_all_tests()
AnalyticdbVectorTest("api").run_all_tests()
AnalyticdbVectorTest("sql").run_all_tests()

View File

@@ -450,6 +450,10 @@ ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5
# TiDB vector configurations, only available when VECTOR_STORE is `tidb`
TIDB_VECTOR_HOST=tidb

View File

@@ -185,6 +185,10 @@ x-shared-env: &shared-api-worker-env
ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}