From 875249eb00ddfe5de669de6b7e7317d247a08509 Mon Sep 17 00:00:00 2001 From: LiuVaayne <10231735+vaayne@users.noreply.github.com> Date: Fri, 10 May 2024 17:20:30 +0800 Subject: [PATCH] Feat/vector db pgvector (#3879) --- .github/workflows/api-tests.yml | 4 +- api/.env.example | 9 +- api/commands.py | 8 + api/config.py | 9 +- api/controllers/console/datasets/datasets.py | 9 +- .../rag/datasource/vdb/pgvector/__init__.py | 0 .../rag/datasource/vdb/pgvector/pgvector.py | 169 ++++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 23 +++ api/requirements.txt | 1 + .../vdb/pgvector/__init__.py | 0 .../vdb/pgvector/test_pgvector.py | 30 ++++ docker/docker-compose.pgvector.yaml | 24 +++ docker/docker-compose.yaml | 39 +++- 13 files changed, 316 insertions(+), 9 deletions(-) create mode 100644 api/core/rag/datasource/vdb/pgvector/__init__.py create mode 100644 api/core/rag/datasource/vdb/pgvector/pgvector.py create mode 100644 api/tests/integration_tests/vdb/pgvector/__init__.py create mode 100644 api/tests/integration_tests/vdb/pgvector/test_pgvector.py create mode 100644 docker/docker-compose.pgvector.yaml diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index a0407de843..7d24b15bdf 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -50,7 +50,7 @@ jobs: - name: Run Workflow run: dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | @@ -58,6 +58,7 @@ jobs: docker/docker-compose.qdrant.yaml docker/docker-compose.milvus.yaml docker/docker-compose.pgvecto-rs.yaml + docker/docker-compose.pgvector.yaml services: | weaviate qdrant @@ -65,6 +66,7 @@ jobs: minio milvus-standalone pgvecto-rs + pgvector - name: Test Vector Stores run: dev/pytest/pytest_vdb.sh diff --git a/api/.env.example b/api/.env.example index 01326a0cc8..e0f87d471a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -65,7 +65,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs +# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector VECTOR_STORE=weaviate # Weaviate configuration @@ -102,6 +102,13 @@ PGVECTO_RS_USER=postgres PGVECTO_RS_PASSWORD=difyai123456 PGVECTO_RS_DATABASE=postgres +# PGVector configuration +PGVECTOR_HOST=127.0.0.1 +PGVECTOR_PORT=5433 +PGVECTOR_USER=postgres +PGVECTOR_PASSWORD=postgres +PGVECTOR_DATABASE=postgres + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/commands.py b/api/commands.py index b82f7ac3f8..75f2491421 100644 --- a/api/commands.py +++ b/api/commands.py @@ -305,6 +305,14 @@ def migrate_knowledge_vector_database(): "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == "pgvector": + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": 'pgvector', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/config.py b/api/config.py index 81fb866d27..4dcd44237a 100644 --- a/api/config.py +++ b/api/config.py @@ -222,7 +222,7 @@ class Config: # ------------------------ # Vector Store Configurations. - # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt + # Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector # ------------------------ self.VECTOR_STORE = get_env('VECTOR_STORE') self.KEYWORD_STORE = get_env('KEYWORD_STORE') @@ -261,6 +261,13 @@ class Config: self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD') self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE') + # pgvector settings + self.PGVECTOR_HOST = get_env('PGVECTOR_HOST') + self.PGVECTOR_PORT = get_env('PGVECTOR_PORT') + self.PGVECTOR_USER = get_env('PGVECTOR_USER') + self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD') + self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE') + # ------------------------ # Mail Configurations. # ------------------------ diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 40ded54120..30dc6ac845 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -476,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource): @account_initialization_required def get(self): vector_type = current_app.config['VECTOR_STORE'] - if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt': + if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}: return { 'retrieval_method': [ 'semantic_search' ] } - elif vector_type == 'qdrant' or vector_type == 'weaviate': + elif vector_type in {"qdrant", "weaviate"}: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' @@ -497,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource): @login_required @account_initialization_required def get(self, vector_type): - - if vector_type == 'milvus' or vector_type == 'relyt': + if vector_type in {'milvus', 'relyt', 'pgvector'}: return { 'retrieval_method': [ 'semantic_search' ] } - elif vector_type == 'qdrant' or vector_type == 'weaviate': + elif vector_type in {'qdrant', 'weaviate'}: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' diff --git a/api/core/rag/datasource/vdb/pgvector/__init__.py b/api/core/rag/datasource/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py new file mode 100644 index 0000000000..22cf790bfa --- /dev/null +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -0,0 +1,169 @@ +import json +import uuid +from contextlib import contextmanager +from typing import Any + +import psycopg2.extras +import psycopg2.pool +from pydantic import BaseModel, root_validator + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + + +class PGVectorConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values["host"]: + raise ValueError("config PGVECTOR_HOST is required") + if not values["port"]: + raise ValueError("config PGVECTOR_PORT is required") + if not values["user"]: + raise ValueError("config PGVECTOR_USER is required") + if not values["password"]: + raise ValueError("config PGVECTOR_PASSWORD is required") + if not values["database"]: + raise ValueError("config PGVECTOR_DATABASE is required") + return values + + +SQL_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS {table_name} ( + id UUID PRIMARY KEY, + text TEXT NOT NULL, + meta JSONB NOT NULL, + embedding vector({dimension}) NOT NULL +) using heap; +""" + + +class PGVector(BaseVector): + def __init__(self, collection_name: str, config: PGVectorConfig): + super().__init__(collection_name) + self.pool = self._create_connection_pool(config) + self.table_name = f"embedding_{collection_name}" + + def get_type(self) -> str: + return "pgvector" + + def _create_connection_pool(self, config: PGVectorConfig): + return psycopg2.pool.SimpleConnectionPool( + 1, + 5, + host=config.host, + port=config.port, + user=config.user, + password=config.password, + database=config.database, + ) + + @contextmanager + def _get_cursor(self): + conn = self.pool.getconn() + cur = conn.cursor() + try: + yield cur + finally: + cur.close() + conn.commit() + self.pool.putconn(conn) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + return self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + values = [] + pks = [] + for i, doc in enumerate(documents): + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) + ) + with self._get_cursor() as cur: + psycopg2.extras.execute_values( + cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values + ) + return pks + + def text_exists(self, id: str) -> bool: + with self._get_cursor() as cur: + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,)) + return cur.fetchone() is not None + + def get_by_ids(self, ids: list[str]) -> list[Document]: + with self._get_cursor() as cur: + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + docs = [] + for record in cur: + docs.append(Document(page_content=record[1], metadata=record[0])) + return docs + + def delete_by_ids(self, ids: list[str]) -> None: + with self._get_cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + with self._get_cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """ + Search the nearest neighbors to a vector. + + :param query_vector: The input vector to search for similar items. + :param top_k: The number of nearest neighbors to return, default is 5. + :return: List of Documents that are nearest to the query vector. + """ + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cur: + cur.execute( + f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}", + (json.dumps(query_vector),), + ) + docs = [] + score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + for record in cur: + metadata, text, distance = record + score = 1 - distance + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # do not support bm25 search + return [] + + def delete(self) -> None: + with self._get_cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + + def _create_collection(self, dimension: int): + cache_key = f"vector_indexing_{self._collection_name}" + lock_name = f"{cache_key}_lock" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + + with self._get_cursor() as cur: + cur.execute("CREATE EXTENSION IF NOT EXISTS vector") + cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) + # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing + redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 2405d16b1d..82ba6139e1 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -164,6 +164,29 @@ class Vector: ), dim=dim ) + elif vector_type == "pgvector": + from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig + + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": "pgvector", + "vector_store": {"class_prefix": collection_name}} + self._dataset.index_struct = json.dumps(index_struct_dict) + return PGVector( + collection_name=collection_name, + config=PGVectorConfig( + host=config.get("PGVECTOR_HOST"), + port=config.get("PGVECTOR_PORT"), + user=config.get("PGVECTOR_USER"), + password=config.get("PGVECTOR_PASSWORD"), + database=config.get("PGVECTOR_DATABASE"), + ), + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/requirements.txt b/api/requirements.txt index e2c430c9d6..6d08202527 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -83,3 +83,4 @@ pydantic~=1.10.0 pgvecto-rs==0.1.4 firecrawl-py==0.0.5 oss2==2.15.0 +pgvector==0.2.5 diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py new file mode 100644 index 0000000000..915bb5e837 --- /dev/null +++ b/api/tests/integration_tests/vdb/pgvector/test_pgvector.py @@ -0,0 +1,30 @@ +from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class TestPGVector(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = PGVector( + collection_name=self.collection_name, + config=PGVectorConfig( + host="localhost", + port=5433, + user="postgres", + password="difyai123456", + database="dify", + ), + ) + + def search_by_full_text(self): + hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_pgvector(setup_mock_redis): + TestPGVector().run_all_tests() diff --git a/docker/docker-compose.pgvector.yaml b/docker/docker-compose.pgvector.yaml new file mode 100644 index 0000000000..b584880abf --- /dev/null +++ b/docker/docker-compose.pgvector.yaml @@ -0,0 +1,24 @@ +version: '3' +services: + # Qdrant vector store. + pgvector: + image: pgvector/pgvector:pg16 + restart: always + environment: + PGUSER: postgres + # The password for the default postgres user. + POSTGRES_PASSWORD: difyai123456 + # The name of the default postgres database. + POSTGRES_DB: dify + # postgres data directory + PGDATA: /var/lib/postgresql/data/pgdata + volumes: + - ./volumes/pgvector/data:/var/lib/postgresql/data + # uncomment to expose db(postgresql) port to host + ports: + - "5433:5432" + healthcheck: + test: [ "CMD", "pg_isready" ] + interval: 1s + timeout: 3s + retries: 30 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 432383b76d..b2a3353641 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -122,6 +122,12 @@ services: RELYT_USER: postgres RELYT_PASSWORD: difyai123456 RELYT_DATABASE: postgres + # pgvector configurations + PGVECTOR_HOST: pgvector + PGVECTOR_PORT: 5432 + PGVECTOR_USER: postgres + PGVECTOR_PASSWORD: difyai123456 + PGVECTOR_DATABASE: dify # Mail configuration, support: resend, smtp MAIL_TYPE: '' # default send from email address, if not specified @@ -211,7 +217,7 @@ services: AZURE_BLOB_ACCOUNT_KEY: 'difyai' AZURE_BLOB_CONTAINER_NAME: 'difyai-container' AZURE_BLOB_ACCOUNT_URL: 'https://.blob.core.windows.net' - # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`. + # The type of vector store to use. Supported values are `weaviate`, `qdrant`, `milvus`, `relyt`, `pgvector`. VECTOR_STORE: weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. WEAVIATE_ENDPOINT: http://weaviate:8080 @@ -251,6 +257,12 @@ services: RELYT_USER: postgres RELYT_PASSWORD: difyai123456 RELYT_DATABASE: postgres + # pgvector configurations + PGVECTOR_HOST: pgvector + PGVECTOR_PORT: 5432 + PGVECTOR_USER: postgres + PGVECTOR_PASSWORD: difyai123456 + PGVECTOR_DATABASE: dify # Notion import configuration, support public and internal NOTION_INTEGRATION_TYPE: public NOTION_CLIENT_SECRET: you-client-secret @@ -374,6 +386,31 @@ services: # # - "6333:6333" # # - "6334:6334" + # The pgvector vector database. + # Uncomment to use qdrant as vector store. + # pgvector: + # image: pgvector/pgvector:pg16 + # restart: always + # environment: + # PGUSER: postgres + # # The password for the default postgres user. + # POSTGRES_PASSWORD: difyai123456 + # # The name of the default postgres database. + # POSTGRES_DB: dify + # # postgres data directory + # PGDATA: /var/lib/postgresql/data/pgdata + # volumes: + # - ./volumes/pgvector/data:/var/lib/postgresql/data + # # uncomment to expose db(postgresql) port to host + # # ports: + # # - "5433:5432" + # healthcheck: + # test: [ "CMD", "pg_isready" ] + # interval: 1s + # timeout: 3s + # retries: 30 + + # The nginx reverse proxy. # used for reverse proxying the API service and Web service. nginx: