diff --git a/api/.env.example b/api/.env.example index 5751605b48..1a242a3daa 100644 --- a/api/.env.example +++ b/api/.env.example @@ -234,6 +234,10 @@ ANALYTICDB_ACCOUNT=testaccount ANALYTICDB_PASSWORD=testpassword ANALYTICDB_NAMESPACE=dify ANALYTICDB_NAMESPACE_PASSWORD=difypassword +ANALYTICDB_HOST=gp-test.aliyuncs.com +ANALYTICDB_PORT=5432 +ANALYTICDB_MIN_CONNECTION=1 +ANALYTICDB_MAX_CONNECTION=5 # OpenSearch configuration OPENSEARCH_HOST=127.0.0.1 diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index 247a8ea555..53cfaae43e 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PositiveInt class AnalyticdbConfig(BaseModel): @@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel): description="The password for accessing the specified namespace within the AnalyticDB instance" " (if namespace feature is enabled).", ) + ANALYTICDB_HOST: Optional[str] = Field( + default=None, description="The host of the AnalyticDB instance you want to connect to." + ) + ANALYTICDB_PORT: PositiveInt = Field( + default=5432, description="The port of the AnalyticDB instance you want to connect to." + ) + ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.") + ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.") diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index c77cb87376..09104ae422 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -1,310 +1,62 @@ import json from typing import Any -from pydantic import BaseModel - -_import_err_msg = ( - "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " - "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" -) - from configs import dify_config +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document -from extensions.ext_redis import redis_client from models.dataset import Dataset -class AnalyticdbConfig(BaseModel): - access_key_id: str - access_key_secret: str - region_id: str - instance_id: str - account: str - account_password: str - namespace: str = ("dify",) - namespace_password: str = (None,) - metrics: str = ("cosine",) - read_timeout: int = 60000 - - def to_analyticdb_client_params(self): - return { - "access_key_id": self.access_key_id, - "access_key_secret": self.access_key_secret, - "region_id": self.region_id, - "read_timeout": self.read_timeout, - } - - class AnalyticdbVector(BaseVector): - def __init__(self, collection_name: str, config: AnalyticdbConfig): - self._collection_name = collection_name.lower() - try: - from alibabacloud_gpdb20160503.client import Client - from 
alibabacloud_tea_openapi import models as open_api_models - except: - raise ImportError(_import_err_msg) - self.config = config - self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) - self._client = Client(self._client_config) - self._initialize() - - def _initialize(self) -> None: - cache_key = f"vector_indexing_{self.config.instance_id}" - lock_name = f"{cache_key}_lock" - with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}" - if redis_client.get(collection_exist_cache_key): - return - self._initialize_vector_database() - self._create_namespace_if_not_exists() - redis_client.set(collection_exist_cache_key, 1, ex=3600) - - def _initialize_vector_database(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - request = gpdb_20160503_models.InitVectorDatabaseRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - manager_account=self.config.account, - manager_account_password=self.config.account_password, - ) - self._client.init_vector_database(request) - - def _create_namespace_if_not_exists(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException - - try: - request = gpdb_20160503_models.DescribeNamespaceRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - manager_account=self.config.account, - manager_account_password=self.config.account_password, - ) - self._client.describe_namespace(request) - except TeaException as e: - if e.statusCode == 404: - request = gpdb_20160503_models.CreateNamespaceRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - manager_account=self.config.account, - manager_account_password=self.config.account_password, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - ) - self._client.create_namespace(request) - else: - raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") - - def _create_collection_if_not_exists(self, embedding_dimension: int): - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException - - cache_key = f"vector_indexing_{self._collection_name}" - lock_name = f"{cache_key}_lock" - with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f"vector_indexing_{self._collection_name}" - if redis_client.get(collection_exist_cache_key): - return - try: - request = gpdb_20160503_models.DescribeCollectionRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - ) - self._client.describe_collection(request) - except TeaException as e: - if e.statusCode == 404: - metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}' - full_text_retrieval_fields = "page_content" - request = gpdb_20160503_models.CreateCollectionRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - manager_account=self.config.account, - manager_account_password=self.config.account_password, - namespace=self.config.namespace, - collection=self._collection_name, - dimension=embedding_dimension, - metrics=self.config.metrics, - metadata=metadata, - full_text_retrieval_fields=full_text_retrieval_fields, - ) - 
self._client.create_collection(request) - else: - raise ValueError(f"failed to create collection {self._collection_name}: {e}") - redis_client.set(collection_exist_cache_key, 1, ex=3600) + def __init__( + self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig + ): + super().__init__(collection_name) + if api_config is not None: + self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) + else: + self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) def get_type(self) -> str: return VectorType.ANALYTICDB def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) - self._create_collection_if_not_exists(dimension) - self.add_texts(texts, embeddings) + self.analyticdb_vector._create_collection_if_not_exists(dimension) + self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] - for doc, embedding in zip(documents, embeddings, strict=True): - metadata = { - "ref_doc_id": doc.metadata["doc_id"], - "page_content": doc.page_content, - "metadata_": json.dumps(doc.metadata), - } - rows.append( - gpdb_20160503_models.UpsertCollectionDataRequestRows( - vector=embedding, - metadata=metadata, - ) - ) - request = gpdb_20160503_models.UpsertCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - rows=rows, - ) - self._client.upsert_collection_data(request) + def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self.analyticdb_vector.add_texts(texts, embeddings) def text_exists(self, id: str) -> bool: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - request = gpdb_20160503_models.QueryCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - metrics=self.config.metrics, - include_values=True, - vector=None, - content=None, - top_k=1, - filter=f"ref_doc_id='{id}'", - ) - response = self._client.query_collection_data(request) - return len(response.body.matches.match) > 0 + return self.analyticdb_vector.text_exists(id) def delete_by_ids(self, ids: list[str]) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - ids_str = ",".join(f"'{id}'" for id in ids) - ids_str = f"({ids_str})" - request = gpdb_20160503_models.DeleteCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - collection_data=None, - collection_data_filter=f"ref_doc_id IN {ids_str}", - ) - self._client.delete_collection_data(request) + self.analyticdb_vector.delete_by_ids(ids) def delete_by_metadata_field(self, key: str, value: str) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - request = gpdb_20160503_models.DeleteCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - 
namespace_password=self.config.namespace_password, - collection=self._collection_name, - collection_data=None, - collection_data_filter=f"metadata_ ->> '{key}' = '{value}'", - ) - self._client.delete_collection_data(request) + self.analyticdb_vector.delete_by_metadata_field(key, value) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - score_threshold = kwargs.get("score_threshold") or 0.0 - request = gpdb_20160503_models.QueryCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - include_values=kwargs.pop("include_values", True), - metrics=self.config.metrics, - vector=query_vector, - content=None, - top_k=kwargs.get("top_k", 4), - filter=None, - ) - response = self._client.query_collection_data(request) - documents = [] - for match in response.body.matches.match: - if match.score > score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) - metadata["score"] = match.score - doc = Document( - page_content=match.metadata.get("page_content"), - metadata=metadata, - ) - documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) - return documents + return self.analyticdb_vector.search_by_vector(query_vector) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - score_threshold = float(kwargs.get("score_threshold") or 0.0) - request = gpdb_20160503_models.QueryCollectionDataRequest( - dbinstance_id=self.config.instance_id, - region_id=self.config.region_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - collection=self._collection_name, - include_values=kwargs.pop("include_values", True), - metrics=self.config.metrics, - vector=None, - content=query, - top_k=kwargs.get("top_k", 4), - filter=None, - ) - response = self._client.query_collection_data(request) - documents = [] - for match in response.body.matches.match: - if match.score > score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) - metadata["score"] = match.score - doc = Document( - page_content=match.metadata.get("page_content"), - vector=match.metadata.get("vector"), - metadata=metadata, - ) - documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) - return documents + return self.analyticdb_vector.search_by_full_text(query, **kwargs) def delete(self) -> None: - try: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - - request = gpdb_20160503_models.DeleteCollectionRequest( - collection=self._collection_name, - dbinstance_id=self.config.instance_id, - namespace=self.config.namespace, - namespace_password=self.config.namespace_password, - region_id=self.config.region_id, - ) - self._client.delete_collection(request) - except Exception as e: - raise e + self.analyticdb_vector.delete() class AnalyticdbVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector: if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() @@ 
-313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) - # handle optional params - if dify_config.ANALYTICDB_KEY_ID is None: - raise ValueError("ANALYTICDB_KEY_ID should not be None") - if dify_config.ANALYTICDB_KEY_SECRET is None: - raise ValueError("ANALYTICDB_KEY_SECRET should not be None") - if dify_config.ANALYTICDB_REGION_ID is None: - raise ValueError("ANALYTICDB_REGION_ID should not be None") - if dify_config.ANALYTICDB_INSTANCE_ID is None: - raise ValueError("ANALYTICDB_INSTANCE_ID should not be None") - if dify_config.ANALYTICDB_ACCOUNT is None: - raise ValueError("ANALYTICDB_ACCOUNT should not be None") - if dify_config.ANALYTICDB_PASSWORD is None: - raise ValueError("ANALYTICDB_PASSWORD should not be None") - if dify_config.ANALYTICDB_NAMESPACE is None: - raise ValueError("ANALYTICDB_NAMESPACE should not be None") - if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None: - raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None") - return AnalyticdbVector( - collection_name, - AnalyticdbConfig( + if dify_config.ANALYTICDB_HOST is None: + # implemented through OpenAPI + apiConfig = AnalyticdbVectorOpenAPIConfig( access_key_id=dify_config.ANALYTICDB_KEY_ID, access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, region_id=dify_config.ANALYTICDB_REGION_ID, @@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): account_password=dify_config.ANALYTICDB_PASSWORD, namespace=dify_config.ANALYTICDB_NAMESPACE, namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, - ), + ) + sqlConfig = None + else: + # implemented through sql + sqlConfig = AnalyticdbVectorBySqlConfig( + host=dify_config.ANALYTICDB_HOST, + port=dify_config.ANALYTICDB_PORT, + account=dify_config.ANALYTICDB_ACCOUNT, + account_password=dify_config.ANALYTICDB_PASSWORD, + min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, + max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, + namespace=dify_config.ANALYTICDB_NAMESPACE, + ) + apiConfig = None + return AnalyticdbVector( + collection_name, + apiConfig, + sqlConfig, ) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py new file mode 100644 index 0000000000..05e0ebc54f --- /dev/null +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -0,0 +1,309 @@ +import json +from typing import Any + +from pydantic import BaseModel, model_validator + +_import_err_msg = ( + "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " + "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" +) + +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + + +class AnalyticdbVectorOpenAPIConfig(BaseModel): + access_key_id: str + access_key_secret: str + region_id: str + instance_id: str + account: str + account_password: str + namespace: str = "dify" + namespace_password: str = (None,) + metrics: str = "cosine" + read_timeout: int = 60000 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["access_key_id"]: + raise ValueError("config ANALYTICDB_KEY_ID is required") + if not values["access_key_secret"]: + raise ValueError("config ANALYTICDB_KEY_SECRET is required") + if not values["region_id"]: + raise 
ValueError("config ANALYTICDB_REGION_ID is required") + if not values["instance_id"]: + raise ValueError("config ANALYTICDB_INSTANCE_ID is required") + if not values["account"]: + raise ValueError("config ANALYTICDB_ACCOUNT is required") + if not values["account_password"]: + raise ValueError("config ANALYTICDB_PASSWORD is required") + if not values["namespace_password"]: + raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") + return values + + def to_analyticdb_client_params(self): + return { + "access_key_id": self.access_key_id, + "access_key_secret": self.access_key_secret, + "region_id": self.region_id, + "read_timeout": self.read_timeout, + } + + +class AnalyticdbVectorOpenAPI: + def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): + try: + from alibabacloud_gpdb20160503.client import Client + from alibabacloud_tea_openapi import models as open_api_models + except: + raise ImportError(_import_err_msg) + self._collection_name = collection_name.lower() + self.config = config + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) + self._client = Client(self._client_config) + self._initialize() + + def _initialize(self) -> None: + cache_key = f"vector_initialize_{self.config.instance_id}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + database_exist_cache_key = f"vector_initialize_{self.config.instance_id}" + if redis_client.get(database_exist_cache_key): + return + self._initialize_vector_database() + self._create_namespace_if_not_exists() + redis_client.set(database_exist_cache_key, 1, ex=3600) + + def _initialize_vector_database(self) -> None: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + request = gpdb_20160503_models.InitVectorDatabaseRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + manager_account=self.config.account, + manager_account_password=self.config.account_password, + ) + self._client.init_vector_database(request) + + def _create_namespace_if_not_exists(self) -> None: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from Tea.exceptions import TeaException + + try: + request = gpdb_20160503_models.DescribeNamespaceRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + manager_account=self.config.account, + manager_account_password=self.config.account_password, + ) + self._client.describe_namespace(request) + except TeaException as e: + if e.statusCode == 404: + request = gpdb_20160503_models.CreateNamespaceRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + manager_account=self.config.account, + manager_account_password=self.config.account_password, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + ) + self._client.create_namespace(request) + else: + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") + + def _create_collection_if_not_exists(self, embedding_dimension: int): + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from Tea.exceptions import TeaException + + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + try: + request = 
gpdb_20160503_models.DescribeCollectionRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + ) + self._client.describe_collection(request) + except TeaException as e: + if e.statusCode == 404: + metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}' + full_text_retrieval_fields = "page_content" + request = gpdb_20160503_models.CreateCollectionRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + manager_account=self.config.account, + manager_account_password=self.config.account_password, + namespace=self.config.namespace, + collection=self._collection_name, + dimension=embedding_dimension, + metrics=self.config.metrics, + metadata=metadata, + full_text_retrieval_fields=full_text_retrieval_fields, + ) + self._client.create_collection(request) + else: + raise ValueError(f"failed to create collection {self._collection_name}: {e}") + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] + for doc, embedding in zip(documents, embeddings, strict=True): + metadata = { + "ref_doc_id": doc.metadata["doc_id"], + "page_content": doc.page_content, + "metadata_": json.dumps(doc.metadata), + } + rows.append( + gpdb_20160503_models.UpsertCollectionDataRequestRows( + vector=embedding, + metadata=metadata, + ) + ) + request = gpdb_20160503_models.UpsertCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + rows=rows, + ) + self._client.upsert_collection_data(request) + + def text_exists(self, id: str) -> bool: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + request = gpdb_20160503_models.QueryCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + metrics=self.config.metrics, + include_values=True, + vector=None, + content=None, + top_k=1, + filter=f"ref_doc_id='{id}'", + ) + response = self._client.query_collection_data(request) + return len(response.body.matches.match) > 0 + + def delete_by_ids(self, ids: list[str]) -> None: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + ids_str = ",".join(f"'{id}'" for id in ids) + ids_str = f"({ids_str})" + request = gpdb_20160503_models.DeleteCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + collection_data=None, + collection_data_filter=f"ref_doc_id IN {ids_str}", + ) + self._client.delete_collection_data(request) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + request = gpdb_20160503_models.DeleteCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + 
namespace_password=self.config.namespace_password, + collection=self._collection_name, + collection_data=None, + collection_data_filter=f"metadata_ ->> '{key}' = '{value}'", + ) + self._client.delete_collection_data(request) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + score_threshold = kwargs.get("score_threshold") or 0.0 + request = gpdb_20160503_models.QueryCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + include_values=kwargs.pop("include_values", True), + metrics=self.config.metrics, + vector=query_vector, + content=None, + top_k=kwargs.get("top_k", 4), + filter=None, + ) + response = self._client.query_collection_data(request) + documents = [] + for match in response.body.matches.match: + if match.score > score_threshold: + metadata = json.loads(match.metadata.get("metadata_")) + metadata["score"] = match.score + doc = Document( + page_content=match.metadata.get("page_content"), + vector=match.values.value, + metadata=metadata, + ) + documents.append(doc) + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + return documents + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + score_threshold = float(kwargs.get("score_threshold") or 0.0) + request = gpdb_20160503_models.QueryCollectionDataRequest( + dbinstance_id=self.config.instance_id, + region_id=self.config.region_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + collection=self._collection_name, + include_values=kwargs.pop("include_values", True), + metrics=self.config.metrics, + vector=None, + content=query, + top_k=kwargs.get("top_k", 4), + filter=None, + ) + response = self._client.query_collection_data(request) + documents = [] + for match in response.body.matches.match: + if match.score > score_threshold: + metadata = json.loads(match.metadata.get("metadata_")) + metadata["score"] = match.score + doc = Document( + page_content=match.metadata.get("page_content"), + vector=match.values.value, + metadata=metadata, + ) + documents.append(doc) + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + return documents + + def delete(self) -> None: + try: + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + + request = gpdb_20160503_models.DeleteCollectionRequest( + collection=self._collection_name, + dbinstance_id=self.config.instance_id, + namespace=self.config.namespace, + namespace_password=self.config.namespace_password, + region_id=self.config.region_id, + ) + self._client.delete_collection(request) + except Exception as e: + raise e diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py new file mode 100644 index 0000000000..e474db5cb2 --- /dev/null +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -0,0 +1,245 @@ +import json +import uuid +from contextlib import contextmanager +from typing import Any + +import psycopg2.extras +import psycopg2.pool +from pydantic import BaseModel, model_validator + +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + + +class 
AnalyticdbVectorBySqlConfig(BaseModel): + host: str + port: int + account: str + account_password: str + min_connection: int + max_connection: int + namespace: str = "dify" + metrics: str = "cosine" + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["host"]: + raise ValueError("config ANALYTICDB_HOST is required") + if not values["port"]: + raise ValueError("config ANALYTICDB_PORT is required") + if not values["account"]: + raise ValueError("config ANALYTICDB_ACCOUNT is required") + if not values["account_password"]: + raise ValueError("config ANALYTICDB_PASSWORD is required") + if not values["min_connection"]: + raise ValueError("config ANALYTICDB_MIN_CONNECTION is required") + if not values["max_connection"]: + raise ValueError("config ANALYTICDB_MAX_CONNECTION is required") + if values["min_connection"] > values["max_connection"]: + raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION") + return values + + +class AnalyticdbVectorBySql: + def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig): + self._collection_name = collection_name.lower() + self.databaseName = "knowledgebase" + self.config = config + self.table_name = f"{self.config.namespace}.{self._collection_name}" + self.pool = None + self._initialize() + if not self.pool: + self.pool = self._create_connection_pool() + + def _initialize(self) -> None: + cache_key = f"vector_initialize_{self.config.host}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + database_exist_cache_key = f"vector_initialize_{self.config.host}" + if redis_client.get(database_exist_cache_key): + return + self._initialize_vector_database() + redis_client.set(database_exist_cache_key, 1, ex=3600) + + def _create_connection_pool(self): + return psycopg2.pool.SimpleConnectionPool( + self.config.min_connection, + self.config.max_connection, + host=self.config.host, + port=self.config.port, + user=self.config.account, + password=self.config.account_password, + database=self.databaseName, + ) + + @contextmanager + def _get_cursor(self): + conn = self.pool.getconn() + cur = conn.cursor() + try: + yield cur + finally: + cur.close() + conn.commit() + self.pool.putconn(conn) + + def _initialize_vector_database(self) -> None: + conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + user=self.config.account, + password=self.config.account_password, + database="postgres", + ) + conn.autocommit = True + cur = conn.cursor() + try: + cur.execute(f"CREATE DATABASE {self.databaseName}") + except Exception as e: + if "already exists" in str(e): + return + raise e + finally: + cur.close() + conn.close() + self.pool = self._create_connection_pool() + with self._get_cursor() as cur: + try: + cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)") + cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple") + except Exception as e: + if "already exists" not in str(e): + raise e + cur.execute( + "CREATE OR REPLACE FUNCTION " + "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) " + "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ " + "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) " + "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) " + "AS words_only;$function$" + ) + cur.execute(f"CREATE SCHEMA IF NOT EXISTS 
{self.config.namespace}") + + def _create_collection_if_not_exists(self, embedding_dimension: int): + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + with self._get_cursor() as cur: + cur.execute( + f"CREATE TABLE IF NOT EXISTS {self.table_name}(" + f"id text PRIMARY KEY," + f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, " + f"to_tsvector TSVECTOR" + f") WITH (fillfactor=70) DISTRIBUTED BY (id);" + ) + if embedding_dimension is not None: + index_name = f"{self._collection_name}_embedding_idx" + cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN") + cur.execute( + f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) " + f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', " + f"pq_enable=0, external_storage=0)" + ) + cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)") + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + values = [] + id_prefix = str(uuid.uuid4()) + "_" + sql = f""" + INSERT INTO {self.table_name} + (id, ref_doc_id, vector, page_content, metadata_, to_tsvector) + VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); + """ + for i, doc in enumerate(documents): + values.append( + ( + id_prefix + str(i), + doc.metadata.get("doc_id", str(uuid.uuid4())), + embeddings[i], + doc.page_content, + json.dumps(doc.metadata), + doc.page_content, + ) + ) + with self._get_cursor() as cur: + psycopg2.extras.execute_batch(cur, sql, values) + + def text_exists(self, id: str) -> bool: + with self._get_cursor() as cur: + cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,)) + return cur.fetchone() is not None + + def delete_by_ids(self, ids: list[str]) -> None: + with self._get_cursor() as cur: + try: + cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),)) + except Exception as e: + if "does not exist" not in str(e): + raise e + + def delete_by_metadata_field(self, key: str, value: str) -> None: + with self._get_cursor() as cur: + try: + cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value)) + except Exception as e: + if "does not exist" not in str(e): + raise e + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + with self._get_cursor() as cur: + query_vector_str = json.dumps(query_vector) + query_vector_str = "{" + query_vector_str[1:-1] + "}" + cur.execute( + f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " + f"t.page_content as page_content, t.metadata_ AS metadata_ " + f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " + f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t", + (query_vector_str,), + ) + documents = [] + for record in cur: + id, vector, score, page_content, metadata = record + if score > score_threshold: + metadata["score"] = score + doc = Document( + page_content=page_content, + vector=vector, + metadata=metadata, + ) + documents.append(doc) + return documents + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 4) 
+ with self._get_cursor() as cur: + cur.execute( + f"""SELECT id, vector, page_content, metadata_, + ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score + FROM {self.table_name} + WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') + ORDER BY score DESC + LIMIT {top_k}""", + (f"'{query}'", f"'{query}'"), + ) + documents = [] + for record in cur: + id, vector, page_content, metadata, score = record + metadata["score"] = score + doc = Document( + page_content=page_content, + vector=vector, + metadata=metadata, + ) + documents.append(doc) + return documents + + def delete(self) -> None: + with self._get_cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index 970b98edc3..4f44d2ffd6 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -1,27 +1,43 @@ from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis class AnalyticdbVectorTest(AbstractVectorTest): - def __init__(self): + def __init__(self, config_type: str): super().__init__() # Analyticdb requires collection_name length less than 60. # it's ok for normal usage. self.collection_name = self.collection_name.replace("_test", "") - self.vector = AnalyticdbVector( - collection_name=self.collection_name, - config=AnalyticdbConfig( - access_key_id="test_key_id", - access_key_secret="test_key_secret", - region_id="test_region", - instance_id="test_id", - account="test_account", - account_password="test_passwd", - namespace="difytest_namespace", - collection="difytest_collection", - namespace_password="test_passwd", - ), - ) + if config_type == "sql": + self.vector = AnalyticdbVector( + collection_name=self.collection_name, + sql_config=AnalyticdbVectorBySqlConfig( + host="test_host", + port=5432, + account="test_account", + account_password="test_passwd", + namespace="difytest_namespace", + ), + api_config=None, + ) + else: + self.vector = AnalyticdbVector( + collection_name=self.collection_name, + sql_config=None, + api_config=AnalyticdbVectorOpenAPIConfig( + access_key_id="test_key_id", + access_key_secret="test_key_secret", + region_id="test_region", + instance_id="test_id", + account="test_account", + account_password="test_passwd", + namespace="difytest_namespace", + collection="difytest_collection", + namespace_password="test_passwd", + ), + ) def run_all_tests(self): self.vector.delete() @@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest): def test_chroma_vector(setup_mock_redis): - AnalyticdbVectorTest().run_all_tests() + AnalyticdbVectorTest("api").run_all_tests() + AnalyticdbVectorTest("sql").run_all_tests() diff --git a/docker/.env.example b/docker/.env.example index d1f53bb789..9b4a506098 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -450,6 +450,10 @@ ANALYTICDB_ACCOUNT=testaccount ANALYTICDB_PASSWORD=testpassword ANALYTICDB_NAMESPACE=dify ANALYTICDB_NAMESPACE_PASSWORD=difypassword +ANALYTICDB_HOST=gp-test.aliyuncs.com +ANALYTICDB_PORT=5432 +ANALYTICDB_MIN_CONNECTION=1 +ANALYTICDB_MAX_CONNECTION=5 # TiDB 
vector configurations, only available when VECTOR_STORE is `tidb` TIDB_VECTOR_HOST=tidb diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 9f33720db4..b6caff90d9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -185,6 +185,10 @@ x-shared-env: &shared-api-worker-env ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-} ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify} ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-} + ANALYTICDB_HOST: ${ANALYTICDB_HOST:-} + ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432} + ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1} + ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5} OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch} OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200} OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
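
For reference, the host-based routing added to AnalyticdbVectorFactory can be exercised on its own. The sketch below mirrors the factory logic in the diff; the helper name build_analyticdb_vector is hypothetical and not part of the change.

# Sketch only: mirrors the ANALYTICDB_HOST-based routing in AnalyticdbVectorFactory.
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig


def build_analyticdb_vector(collection_name: str) -> AnalyticdbVector:
    if dify_config.ANALYTICDB_HOST is None:
        # No host configured: use the Aliyun OpenAPI implementation.
        api_config = AnalyticdbVectorOpenAPIConfig(
            access_key_id=dify_config.ANALYTICDB_KEY_ID,
            access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
            region_id=dify_config.ANALYTICDB_REGION_ID,
            instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
            account=dify_config.ANALYTICDB_ACCOUNT,
            account_password=dify_config.ANALYTICDB_PASSWORD,
            namespace=dify_config.ANALYTICDB_NAMESPACE,
            namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
        )
        sql_config = None
    else:
        # Host configured: connect to the instance directly over psycopg2.
        sql_config = AnalyticdbVectorBySqlConfig(
            host=dify_config.ANALYTICDB_HOST,
            port=dify_config.ANALYTICDB_PORT,
            account=dify_config.ANALYTICDB_ACCOUNT,
            account_password=dify_config.ANALYTICDB_PASSWORD,
            min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
            max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
            namespace=dify_config.ANALYTICDB_NAMESPACE,
        )
        api_config = None
    return AnalyticdbVector(collection_name, api_config, sql_config)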
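
AnalyticdbVectorBySql serves all statements through a psycopg2 SimpleConnectionPool and commits when each cursor is handed back. A stripped-down sketch of that pattern, with placeholder connection details:

# Sketch of the pool + cursor pattern used by AnalyticdbVectorBySql.
# Host, credentials, and database name are placeholders.
from contextlib import contextmanager

import psycopg2.pool

pool = psycopg2.pool.SimpleConnectionPool(
    1,  # min_connection
    5,  # max_connection
    host="gp-test.aliyuncs.com",
    port=5432,
    user="testaccount",
    password="testpassword",
    database="knowledgebase",
)


@contextmanager
def get_cursor():
    conn = pool.getconn()
    cur = conn.cursor()
    try:
        yield cur
    finally:
        # Close the cursor, commit the unit of work, then return the connection to the pool.
        cur.close()
        conn.commit()
        pool.putconn(conn)


with get_cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchone())

Because the commit happens in the finally block, every statement issued through the cursor is persisted before the connection goes back to the pool, which is why the class never calls commit() explicitly elsewhere.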
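
End to end, indexing and querying through the new SQL path might look like the following sketch. The collection name, credentials, and the four-dimensional embedding are placeholders, and it assumes the application's Redis client and an AnalyticDB instance are reachable.

# Usage sketch for the SQL-backed path; values below are placeholders.
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document

vector = AnalyticdbVector(
    collection_name="dify_demo",
    api_config=None,
    sql_config=AnalyticdbVectorBySqlConfig(
        host="gp-test.aliyuncs.com",
        port=5432,
        account="testaccount",
        account_password="testpassword",
        min_connection=1,
        max_connection=5,
        namespace="dify",
    ),
)

docs = [Document(page_content="hello analyticdb", metadata={"doc_id": "doc-1"})]
embeddings = [[0.1, 0.2, 0.3, 0.4]]  # dummy embedding, dimension 4

vector.create(docs, embeddings)          # creates the table and ANN index, then inserts the rows
print(vector.text_exists("doc-1"))       # True once the row is written
print(vector.search_by_vector([0.1, 0.2, 0.3, 0.4]))
vector.delete()                          # drops the backing table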
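
Both new config models validate eagerly with a before-mode model_validator, so a missing setting fails at construction time rather than on the first query. A small sketch of that failure path (the empty host is deliberate):

# Sketch: eager validation in AnalyticdbVectorBySqlConfig; other values are placeholders.
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig

try:
    AnalyticdbVectorBySqlConfig(
        host="",  # empty -> "config ANALYTICDB_HOST is required"
        port=5432,
        account="testaccount",
        account_password="testpassword",
        min_connection=1,
        max_connection=5,
    )
except ValueError as e:
    # pydantic wraps the validator's ValueError in a ValidationError (a ValueError subclass)
    print(e)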
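
The keyword path relies on the zh_cn text-search configuration and the to_tsquery_from_text helper that _initialize_vector_database creates. A sketch of the ranking query it issues, run directly with psycopg2 against a placeholder dify.dify_demo table; note the PR additionally wraps the bound query string in single quotes before binding.

# Sketch: full-text ranking over the to_tsvector column created by the PR's DDL.
# Connection details and the table name are placeholders.
import psycopg2

conn = psycopg2.connect(
    host="gp-test.aliyuncs.com",
    port=5432,
    user="testaccount",
    password="testpassword",
    database="knowledgebase",
)
with conn, conn.cursor() as cur:
    cur.execute(
        """
        SELECT id, page_content,
               ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
        FROM dify.dify_demo
        WHERE to_tsvector @@ to_tsquery_from_text(%s, 'zh_cn')
        ORDER BY score DESC
        LIMIT 4
        """,
        ("hello", "hello"),
    )
    for row in cur.fetchall():
        print(row)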