From bdad993901c2a4887112944995c0bdccbedd62a1 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Sat, 8 Jun 2024 22:29:24 +0800 Subject: [PATCH] improve: generalize vector factory classes and vector type (#5033) --- api/commands.py | 21 +- api/controllers/console/datasets/datasets.py | 61 ++--- .../datasource/vdb/milvus/milvus_vector.py | 36 ++- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 33 ++- .../rag/datasource/vdb/pgvector/pgvector.py | 31 ++- .../datasource/vdb/qdrant/qdrant_vector.py | 46 +++- .../rag/datasource/vdb/relyt/relyt_vector.py | 42 +++- .../datasource/vdb/tidb_vector/tidb_vector.py | 33 +++ api/core/rag/datasource/vdb/vector_base.py | 4 + api/core/rag/datasource/vdb/vector_factory.py | 229 ++++-------------- api/core/rag/datasource/vdb/vector_type.py | 11 + .../vdb/weaviate/weaviate_vector.py | 29 ++- 12 files changed, 343 insertions(+), 233 deletions(-) create mode 100644 api/core/rag/datasource/vdb/vector_type.py diff --git a/api/commands.py b/api/commands.py index 186b97c3fa..da3f7416d4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound from constants.languages import languages from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_database import db from libs.helper import email as email_validate @@ -266,15 +267,15 @@ def migrate_knowledge_vector_database(): skipped_count = skipped_count + 1 continue collection_name = '' - if vector_type == "weaviate": + if vector_type == VectorType.WEAVIATE: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'weaviate', + "type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == "qdrant": + elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: 
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ @@ -287,20 +288,20 @@ def migrate_knowledge_vector_database(): dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'qdrant', + "type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == "milvus": + elif vector_type == VectorType.MILVUS: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'milvus', + "type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == "relyt": + elif vector_type == VectorType.RELYT: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { @@ -308,16 +309,16 @@ def migrate_knowledge_vector_database(): "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == "pgvector": + elif vector_type == VectorType.PGVECTOR: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = { - "type": 'pgvector', + "type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) else: - raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + raise ValueError(f"Vector store {vector_type} is not supported.") vector = Vector(dataset) click.echo(f"Start to migrate dataset {dataset.id}.") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 72c4c09055..49e50caf70 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -15,6 +15,7 
@@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from fields.app_fields import related_app_list @@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource): @account_initialization_required def get(self): vector_type = current_app.config['VECTOR_STORE'] - if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}: - return { - 'retrieval_method': [ - 'semantic_search' - ] - } - elif vector_type in {"qdrant", "weaviate"}: - return { - 'retrieval_method': [ - 'semantic_search', 'full_text_search', 'hybrid_search' - ] - } - else: - raise ValueError("Unsupported vector db type.") + + match vector_type: + case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.PGVECTO_RS: + return { + 'retrieval_method': [ + 'semantic_search' + ] + } + case VectorType.QDRANT | VectorType.WEAVIATE: + return { + 'retrieval_method': [ + 'semantic_search', 'full_text_search', 'hybrid_search' + ] + } + case _: + raise ValueError(f"Unsupported vector db type {vector_type}.") class DatasetRetrievalSettingMockApi(Resource): @@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource): @login_required @account_initialization_required def get(self, vector_type): - if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}: - return { - 'retrieval_method': [ - 'semantic_search' - ] - } - elif vector_type in {'qdrant', 'weaviate'}: - return { - 'retrieval_method': [ - 'semantic_search', 'full_text_search', 'hybrid_search' - ] - } - else: - raise ValueError("Unsupported vector db type.") + match vector_type: + case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | 
VectorType.TIDB_VECTOR: + return { + 'retrieval_method': [ + 'semantic_search' + ] + } + case VectorType.QDRANT | VectorType.WEAVIATE: + return { + 'retrieval_method': [ + 'semantic_search', 'full_text_search', 'hybrid_search' + ] + } + case _: + raise ValueError(f"Unsupported vector db type {vector_type}.") + class DatasetErrorDocs(Resource): @setup_required diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 0586e279d3..d77cf26d25 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,14 +1,20 @@ +import json import logging from typing import Any, Optional from uuid import uuid4 +from flask import current_app from pydantic import BaseModel, root_validator from pymilvus import MilvusClient, MilvusException, connections +from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_redis import redis_client +from models.dataset import Dataset logger = logging.getLogger(__name__) @@ -55,7 +61,7 @@ class MilvusVector(BaseVector): self._fields = [] def get_type(self) -> str: - return 'milvus' + return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_params = { @@ -254,10 +260,36 @@ class MilvusVector(BaseVector): schema=schema, index_param=index_params, consistency_level=self._consistency_level) redis_client.set(collection_exist_cache_key, 1, ex=3600) + def _init_client(self, config) -> MilvusClient: if config.secure: uri = "https://" + str(config.host) + ":" + str(config.port) else: uri = "http://" + str(config.host) + ":" + str(config.port) - client = 
MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database) + client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database) return client + + +class MilvusVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + + config = current_app.config + return MilvusVector( + collection_name=collection_name, + config=MilvusConfig( + host=config.get('MILVUS_HOST'), + port=config.get('MILVUS_PORT'), + user=config.get('MILVUS_USER'), + password=config.get('MILVUS_PASSWORD'), + secure=config.get('MILVUS_SECURE'), + database=config.get('MILVUS_DATABASE'), + ) + ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 3842aee6c7..80a2c3f82b 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -1,7 +1,9 @@ +import json import logging from typing import Any from uuid import UUID, uuid4 +from flask import current_app from numpy import ndarray from pgvecto_rs.sqlalchemy import Vector from pydantic import BaseModel, root_validator @@ -10,10 +12,14 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column +from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory 
+from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_redis import redis_client +from models.dataset import Dataset logger = logging.getLogger(__name__) @@ -67,7 +73,7 @@ class PGVectoRS(BaseVector): self._distance_op = "<=>" def get_type(self) -> str: - return 'pgvecto-rs' + return VectorType.PGVECTO_RS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self.create_collection(len(embeddings[0])) @@ -222,3 +228,28 @@ class PGVectoRS(BaseVector): # docs.append(doc) # return docs return [] + + +class PGVectoRSFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name)) + dim = len(embeddings.embed_query("pgvecto_rs")) + config = current_app.config + return PGVectoRS( + collection_name=collection_name, + config=PgvectoRSConfig( + host=config.get('PGVECTO_RS_HOST'), + port=config.get('PGVECTO_RS_PORT'), + user=config.get('PGVECTO_RS_USER'), + password=config.get('PGVECTO_RS_PASSWORD'), + database=config.get('PGVECTO_RS_DATABASE'), + ), + dim=dim + ) \ No newline at end of file diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 22cf790bfa..c9a1508ab2 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -5,11 +5,16 @@ from typing import Any import psycopg2.extras import psycopg2.pool +from flask import current_app from pydantic import BaseModel, root_validator +from core.rag.datasource.entity.embedding import 
Embeddings from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_redis import redis_client +from models.dataset import Dataset class PGVectorConfig(BaseModel): @@ -51,7 +56,7 @@ class PGVector(BaseVector): self.table_name = f"embedding_{collection_name}" def get_type(self) -> str: - return "pgvector" + return VectorType.PGVECTOR def _create_connection_pool(self, config: PGVectorConfig): return psycopg2.pool.SimpleConnectionPool( @@ -167,3 +172,27 @@ class PGVector(BaseVector): cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class PGVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + + config = current_app.config + return PGVector( + collection_name=collection_name, + config=PGVectorConfig( + host=config.get("PGVECTOR_HOST"), + port=config.get("PGVECTOR_PORT"), + user=config.get("PGVECTOR_USER"), + password=config.get("PGVECTOR_PASSWORD"), + database=config.get("PGVECTOR_DATABASE"), + ), + ) \ No newline at end of file diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 7a92314542..6a77c135ff 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ 
b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -1,3 +1,4 @@ +import json import os import uuid from collections.abc import Generator, Iterable, Sequence @@ -5,6 +6,7 @@ from itertools import islice from typing import TYPE_CHECKING, Any, Optional, Union, cast import qdrant_client +from flask import current_app from pydantic import BaseModel from qdrant_client.http import models as rest from qdrant_client.http.models import ( @@ -17,10 +19,15 @@ from qdrant_client.http.models import ( ) from qdrant_client.local.qdrant_local import QdrantLocal +from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.dataset import Dataset, DatasetCollectionBinding if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -69,7 +76,7 @@ class QdrantVector(BaseVector): self._group_id = group_id def get_type(self) -> str: - return 'qdrant' + return VectorType.QDRANT def to_index_struct(self) -> dict: return { @@ -408,3 +415,40 @@ class QdrantVector(BaseVector): page_content=scored_point.payload.get(content_payload_key), metadata=scored_point.payload.get(metadata_payload_key) or {}, ) + + +class QdrantVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: + if dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == dataset.collection_binding_id). 
\ + one_or_none() + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError('Dataset Collection Bindings is not exist!') + else: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + + if not dataset.index_struct_dict: + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + + config = current_app.config + return QdrantVector( + collection_name=collection_name, + group_id=dataset.id, + config=QdrantConfig( + endpoint=config.get('QDRANT_URL'), + api_key=config.get('QDRANT_API_KEY'), + root_path=config.root_path, + timeout=config.get('QDRANT_CLIENT_TIMEOUT'), + grpc_port=config.get('QDRANT_GRPC_PORT'), + prefer_grpc=config.get('QDRANT_GRPC_ENABLED') + ) + ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index ee88d9fa29..5ccb24f57f 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -1,12 +1,19 @@ +import json import uuid from typing import Any, Optional +from flask import current_app from pydantic import BaseModel, root_validator from sqlalchemy import Column, Sequence, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from models.dataset import Dataset + try: from sqlalchemy.orm import declarative_base except ImportError: @@ -53,7 +60,7 @@ class RelytVector(BaseVector): self._group_id = group_id def get_type(self) -> str: - 
return 'relyt' + return VectorType.RELYT def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_params = {} @@ -240,10 +247,10 @@ class RelytVector(BaseVector): return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: @@ -298,3 +305,28 @@ class RelytVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # milvus/zilliz/relyt doesn't support bm25 search return [] + + +class RelytVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.RELYT, collection_name)) + + config = current_app.config + return RelytVector( + collection_name=collection_name, + config=RelytConfig( + host=config.get('RELYT_HOST'), + port=config.get('RELYT_PORT'), + user=config.get('RELYT_USER'), + password=config.get('RELYT_PASSWORD'), + database=config.get('RELYT_DATABASE'), + ), + group_id=dataset.id + ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 107d17bb47..6564c565d1 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -3,14 +3,19 @@ import logging from typing import Any import sqlalchemy +from flask import current_app from pydantic import BaseModel, root_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, 
create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base +from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_redis import redis_client +from models.dataset import Dataset logger = logging.getLogger(__name__) @@ -39,6 +44,9 @@ class TiDBVectorConfig(BaseModel): class TiDBVector(BaseVector): + def get_type(self) -> str: + return VectorType.TIDB_VECTOR + def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType return Table( @@ -214,3 +222,28 @@ class TiDBVector(BaseVector): with Session(self._engine) as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) session.commit() + + +class TiDBVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: + + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + + config = current_app.config + return TiDBVector( + collection_name=collection_name, + config=TiDBVectorConfig( + host=config.get('TIDB_VECTOR_HOST'), + port=config.get('TIDB_VECTOR_PORT'), + user=config.get('TIDB_VECTOR_USER'), + password=config.get('TIDB_VECTOR_PASSWORD'), + database=config.get('TIDB_VECTOR_DATABASE'), + ), + ) \ No newline at end of file diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index c5212aa8f2..9b414e4e12 100644 --- 
a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -11,6 +11,10 @@ class BaseVector(ABC): def __init__(self, collection_name: str): self._collection_name = collection_name + @abstractmethod + def get_type(self) -> str: + raise NotImplementedError + @abstractmethod def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): raise NotImplementedError diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b500b37d60..852dc51a3a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,4 +1,4 @@ -import json +from abc import ABC, abstractmethod from typing import Any from flask import current_app @@ -8,9 +8,23 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document -from extensions.ext_database import db -from models.dataset import Dataset, DatasetCollectionBinding +from models.dataset import Dataset + + +class AbstractVectorFactory(ABC): + @abstractmethod + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: + raise NotImplementedError + + @staticmethod + def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: + index_struct_dict = { + "type": vector_type, + "vector_store": {"class_prefix": collection_name} + } + return index_struct_dict class Vector: @@ -32,188 +46,35 @@ class Vector: if not vector_type: raise ValueError("Vector store must be specified.") - if vector_type == "weaviate": - from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector - if self._dataset.index_struct_dict: - class_prefix: str = 
self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'weaviate', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - return WeaviateVector( - collection_name=collection_name, - config=WeaviateConfig( - endpoint=config.get('WEAVIATE_ENDPOINT'), - api_key=config.get('WEAVIATE_API_KEY'), - batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) - ), - attributes=self._attributes - ) - elif vector_type == "qdrant": - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector - if self._dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \ - one_or_none() - if dataset_collection_binding: - collection_name = dataset_collection_binding.collection_name - else: - raise ValueError('Dataset Collection Bindings is not exist!') - else: - if self._dataset.index_struct_dict: - class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) + vector_factory_cls = self.get_vector_factory(vector_type) + return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings) - if not self._dataset.index_struct_dict: - index_struct_dict = { - "type": 'qdrant', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - - return QdrantVector( - collection_name=collection_name, - group_id=self._dataset.id, - config=QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), - root_path=current_app.root_path, - 
timeout=config.get('QDRANT_CLIENT_TIMEOUT'), - grpc_port=config.get('QDRANT_GRPC_PORT'), - prefer_grpc=config.get('QDRANT_GRPC_ENABLED') - ) - ) - elif vector_type == "milvus": - from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector - if self._dataset.index_struct_dict: - class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'milvus', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - return MilvusVector( - collection_name=collection_name, - config=MilvusConfig( - host=config.get('MILVUS_HOST'), - port=config.get('MILVUS_PORT'), - user=config.get('MILVUS_USER'), - password=config.get('MILVUS_PASSWORD'), - secure=config.get('MILVUS_SECURE'), - database=config.get('MILVUS_DATABASE'), - ) - ) - elif vector_type == "relyt": - from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector - if self._dataset.index_struct_dict: - class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": 'relyt', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - return RelytVector( - collection_name=collection_name, - config=RelytConfig( - host=config.get('RELYT_HOST'), - port=config.get('RELYT_PORT'), - user=config.get('RELYT_USER'), - password=config.get('RELYT_PASSWORD'), - database=config.get('RELYT_DATABASE'), - ), - group_id=self._dataset.id - ) - elif vector_type == "pgvecto_rs": - from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig - if self._dataset.index_struct_dict: - 
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix.lower() - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": 'pgvecto_rs', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - dim = len(self._embeddings.embed_query("pgvecto_rs")) - return PGVectoRS( - collection_name=collection_name, - config=PgvectoRSConfig( - host=config.get('PGVECTO_RS_HOST'), - port=config.get('PGVECTO_RS_PORT'), - user=config.get('PGVECTO_RS_USER'), - password=config.get('PGVECTO_RS_PASSWORD'), - database=config.get('PGVECTO_RS_DATABASE'), - ), - dim=dim - ) - elif vector_type == "pgvector": - from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig - - if self._dataset.index_struct_dict: - class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"] - collection_name = class_prefix - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": "pgvector", - "vector_store": {"class_prefix": collection_name}} - self._dataset.index_struct = json.dumps(index_struct_dict) - return PGVector( - collection_name=collection_name, - config=PGVectorConfig( - host=config.get("PGVECTOR_HOST"), - port=config.get("PGVECTOR_PORT"), - user=config.get("PGVECTOR_USER"), - password=config.get("PGVECTOR_PASSWORD"), - database=config.get("PGVECTOR_DATABASE"), - ), - ) - elif vector_type == "tidb_vector": - from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig - - if self._dataset.index_struct_dict: - class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] - collection_name = class_prefix.lower() - else: - dataset_id = self._dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() 
- index_struct_dict = { - "type": 'tidb_vector', - "vector_store": {"class_prefix": collection_name} - } - self._dataset.index_struct = json.dumps(index_struct_dict) - - return TiDBVector( - collection_name=collection_name, - config=TiDBVectorConfig( - host=config.get('TIDB_VECTOR_HOST'), - port=config.get('TIDB_VECTOR_PORT'), - user=config.get('TIDB_VECTOR_USER'), - password=config.get('TIDB_VECTOR_PASSWORD'), - database=config.get('TIDB_VECTOR_DATABASE'), - ), - ) - else: - raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + @staticmethod + def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: + match vector_type: + case VectorType.MILVUS: + from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory + case VectorType.PGVECTOR: + from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory + case VectorType.PGVECTO_RS: + from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory + case VectorType.QDRANT: + from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory + case VectorType.RELYT: + from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory + case VectorType.TIDB_VECTOR: + from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory + case VectorType.WEAVIATE: + from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory + case _: + raise ValueError(f"Vector store {vector_type} is not supported.") def create(self, texts: list = None, **kwargs): if texts: diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py new file mode 100644 index 0000000000..f2e9b9b8d4 --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -0,0 +1,11 @@ +from enum import 
Enum + + +class VectorType(str, Enum): + MILVUS = 'milvus' + PGVECTOR = 'pgvector' + PGVECTO_RS = 'pgvecto-rs' + QDRANT = 'qdrant' + RELYT = 'relyt' + TIDB_VECTOR = 'tidb_vector' + WEAVIATE = 'weaviate' diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 93c7480f8b..73f9c580aa 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -1,12 +1,17 @@ import datetime +import json from typing import Any, Optional import requests import weaviate +from flask import current_app from pydantic import BaseModel, root_validator +from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -59,7 +64,7 @@ class WeaviateVector(BaseVector): return client def get_type(self) -> str: - return 'weaviate' + return VectorType.WEAVIATE def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: @@ -255,3 +260,25 @@ class WeaviateVector(BaseVector): if isinstance(value, datetime.datetime): return value.isoformat() return value + + +class WeaviateVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + + return 
WeaviateVector( + collection_name=collection_name, + config=WeaviateConfig( + endpoint=current_app.config.get('WEAVIATE_ENDPOINT'), + api_key=current_app.config.get('WEAVIATE_API_KEY'), + batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE')) + ), + attributes=attributes + )