improve: generalize vector factory classes and vector type (#5033)

This commit is contained in:
Bowen Liang 2024-06-08 22:29:24 +08:00 committed by GitHub
parent 3b62ab564a
commit bdad993901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 343 additions and 233 deletions

View File

@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
from constants.languages import languages from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email as email_validate from libs.helper import email as email_validate
@ -266,15 +267,15 @@ def migrate_knowledge_vector_database():
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
continue continue
collection_name = '' collection_name = ''
if vector_type == "weaviate": if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": 'weaviate', "type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant": elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
@ -287,20 +288,20 @@ def migrate_knowledge_vector_database():
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": 'qdrant', "type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus": elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": 'milvus', "type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "relyt": elif vector_type == VectorType.RELYT:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
@ -308,16 +309,16 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "pgvector": elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = { index_struct_dict = {
"type": 'pgvector', "type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name} "vector_store": {"class_prefix": collection_name}
} }
dataset.index_struct = json.dumps(index_struct_dict) dataset.index_struct = json.dumps(index_struct_dict)
else: else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")
vector = Vector(dataset) vector = Vector(dataset)
click.echo(f"Start to migrate dataset {dataset.id}.") click.echo(f"Start to migrate dataset {dataset.id}.")

View File

@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
vector_type = current_app.config['VECTOR_STORE'] vector_type = current_app.config['VECTOR_STORE']
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
return { match vector_type:
'retrieval_method': [ case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
'semantic_search' return {
] 'retrieval_method': [
} 'semantic_search'
elif vector_type in {"qdrant", "weaviate"}: ]
return { }
'retrieval_method': [ case VectorType.QDRANT | VectorType.WEAVIATE:
'semantic_search', 'full_text_search', 'hybrid_search' return {
] 'retrieval_method': [
} 'semantic_search', 'full_text_search', 'hybrid_search'
else: ]
raise ValueError("Unsupported vector db type.") }
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetRetrievalSettingMockApi(Resource): class DatasetRetrievalSettingMockApi(Resource):
@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, vector_type): def get(self, vector_type):
if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}: match vector_type:
return { case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
'retrieval_method': [ return {
'semantic_search' 'retrieval_method': [
] 'semantic_search'
} ]
elif vector_type in {'qdrant', 'weaviate'}: }
return { case VectorType.QDRANT | VectorType.WEAVIATE:
'retrieval_method': [ return {
'semantic_search', 'full_text_search', 'hybrid_search' 'retrieval_method': [
] 'semantic_search', 'full_text_search', 'hybrid_search'
} ]
else: }
raise ValueError("Unsupported vector db type.") case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource): class DatasetErrorDocs(Resource):
@setup_required @setup_required

View File

@ -1,14 +1,20 @@
import json
import logging import logging
from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4 from uuid import uuid4
from flask import current_app
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,7 +61,7 @@ class MilvusVector(BaseVector):
self._fields = [] self._fields = []
def get_type(self) -> str: def get_type(self) -> str:
return 'milvus' return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = { index_params = {
@ -254,10 +260,36 @@ class MilvusVector(BaseVector):
schema=schema, index_param=index_params, schema=schema, index_param=index_params,
consistency_level=self._consistency_level) consistency_level=self._consistency_level)
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient: def _init_client(self, config) -> MilvusClient:
if config.secure: if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port) uri = "https://" + str(config.host) + ":" + str(config.port)
else: else:
uri = "http://" + str(config.host) + ":" + str(config.port) uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database) client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
return client return client
class MilvusVectorFactory(AbstractVectorFactory):
    """Factory that builds a :class:`MilvusVector` bound to the given dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
        """Create a MilvusVector for *dataset*.

        Reuses the collection name recorded in the dataset's index struct if
        present; otherwise derives one from the dataset id and persists the
        index struct so later loads find the same collection.
        """
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            # Fix: record the Milvus vector type here, not Weaviate's —
            # otherwise the persisted index struct misidentifies the backend.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.MILVUS, collection_name))

        config = current_app.config
        return MilvusVector(
            collection_name=collection_name,
            config=MilvusConfig(
                host=config.get('MILVUS_HOST'),
                port=config.get('MILVUS_PORT'),
                user=config.get('MILVUS_USER'),
                password=config.get('MILVUS_PASSWORD'),
                secure=config.get('MILVUS_SECURE'),
                database=config.get('MILVUS_DATABASE'),
            )
        )

View File

@ -1,7 +1,9 @@
import json
import logging import logging
from typing import Any from typing import Any
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from flask import current_app
from numpy import ndarray from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -10,10 +12,14 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,7 +73,7 @@ class PGVectoRS(BaseVector):
self._distance_op = "<=>" self._distance_op = "<=>"
def get_type(self) -> str: def get_type(self) -> str:
return 'pgvecto-rs' return VectorType.PGVECTO_RS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0])) self.create_collection(len(embeddings[0]))
@ -222,3 +228,28 @@ class PGVectoRS(BaseVector):
# docs.append(doc) # docs.append(doc)
# return docs # return docs
return [] return []
class PGVectoRSFactory(AbstractVectorFactory):
    """Factory that builds a :class:`PGVectoRS` vector store for a dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
        """Create a PGVectoRS instance for *dataset*.

        Collection names are lower-cased for pgvecto.rs. On first use the
        index struct is persisted on the dataset; the embedding dimension is
        probed with a sample query.
        """
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            # Fix: record the pgvecto.rs vector type, not Weaviate's —
            # the pre-refactor code wrote "pgvecto_rs" here.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name))

        # Probe the model once to learn the embedding dimensionality.
        dim = len(embeddings.embed_query("pgvecto_rs"))
        config = current_app.config
        return PGVectoRS(
            collection_name=collection_name,
            config=PgvectoRSConfig(
                host=config.get('PGVECTO_RS_HOST'),
                port=config.get('PGVECTO_RS_PORT'),
                user=config.get('PGVECTO_RS_USER'),
                password=config.get('PGVECTO_RS_PASSWORD'),
                database=config.get('PGVECTO_RS_DATABASE'),
            ),
            dim=dim
        )

View File

@ -5,11 +5,16 @@ from typing import Any
import psycopg2.extras import psycopg2.extras
import psycopg2.pool import psycopg2.pool
from flask import current_app
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset
class PGVectorConfig(BaseModel): class PGVectorConfig(BaseModel):
@ -51,7 +56,7 @@ class PGVector(BaseVector):
self.table_name = f"embedding_{collection_name}" self.table_name = f"embedding_{collection_name}"
def get_type(self) -> str: def get_type(self) -> str:
return "pgvector" return VectorType.PGVECTOR
def _create_connection_pool(self, config: PGVectorConfig): def _create_connection_pool(self, config: PGVectorConfig):
return psycopg2.pool.SimpleConnectionPool( return psycopg2.pool.SimpleConnectionPool(
@ -167,3 +172,27 @@ class PGVector(BaseVector):
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
class PGVectorFactory(AbstractVectorFactory):
    """Factory producing :class:`PGVector` stores backed by PostgreSQL."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector:
        """Return a PGVector for *dataset*, creating its index struct on first use."""
        index_struct = dataset.index_struct_dict
        if index_struct:
            # An existing index struct pins the collection name.
            collection_name: str = index_struct["vector_store"]["class_prefix"]
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))

        app_config = current_app.config
        pg_config = PGVectorConfig(
            host=app_config.get("PGVECTOR_HOST"),
            port=app_config.get("PGVECTOR_PORT"),
            user=app_config.get("PGVECTOR_USER"),
            password=app_config.get("PGVECTOR_PASSWORD"),
            database=app_config.get("PGVECTOR_DATABASE"),
        )
        return PGVector(collection_name=collection_name, config=pg_config)

View File

@ -1,3 +1,4 @@
import json
import os import os
import uuid import uuid
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
@ -5,6 +6,7 @@ from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client import qdrant_client
from flask import current_app
from pydantic import BaseModel from pydantic import BaseModel
from qdrant_client.http import models as rest from qdrant_client.http import models as rest
from qdrant_client.http.models import ( from qdrant_client.http.models import (
@ -17,10 +19,15 @@ from qdrant_client.http.models import (
) )
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING: if TYPE_CHECKING:
from qdrant_client import grpc # noqa from qdrant_client import grpc # noqa
@ -69,7 +76,7 @@ class QdrantVector(BaseVector):
self._group_id = group_id self._group_id = group_id
def get_type(self) -> str: def get_type(self) -> str:
return 'qdrant' return VectorType.QDRANT
def to_index_struct(self) -> dict: def to_index_struct(self) -> dict:
return { return {
@ -408,3 +415,40 @@ class QdrantVector(BaseVector):
page_content=scored_point.payload.get(content_payload_key), page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {}, metadata=scored_point.payload.get(metadata_payload_key) or {},
) )
class QdrantVectorFactory(AbstractVectorFactory):
    """Factory producing :class:`QdrantVector` stores.

    Resolution order for the collection name: an explicit dataset collection
    binding wins; otherwise the persisted index struct; otherwise a fresh
    name derived from the dataset id.
    """

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
        if dataset.collection_binding_id:
            binding = (
                db.session.query(DatasetCollectionBinding)
                .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
                .one_or_none()
            )
            if not binding:
                raise ValueError('Dataset Collection Bindings is not exist!')
            collection_name = binding.collection_name
        elif dataset.index_struct_dict:
            collection_name = dataset.index_struct_dict['vector_store']['class_prefix']
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)

        # Persist the index struct the first time this dataset is vectorized,
        # even when the name came from a collection binding.
        if not dataset.index_struct_dict:
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.QDRANT, collection_name))

        app_config = current_app.config
        qdrant_config = QdrantConfig(
            endpoint=app_config.get('QDRANT_URL'),
            api_key=app_config.get('QDRANT_API_KEY'),
            root_path=app_config.root_path,
            timeout=app_config.get('QDRANT_CLIENT_TIMEOUT'),
            grpc_port=app_config.get('QDRANT_GRPC_PORT'),
            prefer_grpc=app_config.get('QDRANT_GRPC_ENABLED')
        )
        return QdrantVector(
            collection_name=collection_name,
            group_id=dataset.id,
            config=qdrant_config
        )

View File

@ -1,12 +1,19 @@
import json
import uuid import uuid
from typing import Any, Optional from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from models.dataset import Dataset
try: try:
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base
except ImportError: except ImportError:
@ -53,7 +60,7 @@ class RelytVector(BaseVector):
self._group_id = group_id self._group_id = group_id
def get_type(self) -> str: def get_type(self) -> str:
return 'relyt' return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {} index_params = {}
@ -240,10 +247,10 @@ class RelytVector(BaseVector):
return docs return docs
def similarity_search_with_score_by_vector( def similarity_search_with_score_by_vector(
self, self,
embedding: list[float], embedding: list[float],
k: int = 4, k: int = 4,
filter: Optional[dict] = None, filter: Optional[dict] = None,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
# Add the filter if provided # Add the filter if provided
try: try:
@ -298,3 +305,28 @@ class RelytVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz/relyt doesn't support bm25 search # milvus/zilliz/relyt doesn't support bm25 search
return [] return []
class RelytVectorFactory(AbstractVectorFactory):
    """Factory producing :class:`RelytVector` stores."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
        """Return a RelytVector for *dataset*, persisting its index struct on first use."""
        index_struct = dataset.index_struct_dict
        if index_struct:
            # Existing datasets keep their recorded collection.
            collection_name: str = index_struct['vector_store']['class_prefix']
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.RELYT, collection_name))

        app_config = current_app.config
        relyt_config = RelytConfig(
            host=app_config.get('RELYT_HOST'),
            port=app_config.get('RELYT_PORT'),
            user=app_config.get('RELYT_USER'),
            password=app_config.get('RELYT_PASSWORD'),
            database=app_config.get('RELYT_DATABASE'),
        )
        return RelytVector(
            collection_name=collection_name,
            config=relyt_config,
            group_id=dataset.id
        )

View File

@ -3,14 +3,19 @@ import logging
from typing import Any from typing import Any
import sqlalchemy import sqlalchemy
from flask import current_app
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base from sqlalchemy.orm import Session, declarative_base
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,6 +44,9 @@ class TiDBVectorConfig(BaseModel):
class TiDBVector(BaseVector): class TiDBVector(BaseVector):
def get_type(self) -> str:
    # Identifies this store as the TiDB vector backend; VectorType is a
    # str-compatible enum member, satisfying the declared return type.
    return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table: def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType from tidb_vector.sqlalchemy import VectorType
return Table( return Table(
@ -214,3 +222,28 @@ class TiDBVector(BaseVector):
with Session(self._engine) as session: with Session(self._engine) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit() session.commit()
class TiDBVectorFactory(AbstractVectorFactory):
    """Factory producing :class:`TiDBVector` stores."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
        """Return a TiDBVector for *dataset* (collection names are lower-cased)."""
        index_struct = dataset.index_struct_dict
        if index_struct:
            collection_name: str = index_struct['vector_store']['class_prefix'].lower()
        else:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))

        app_config = current_app.config
        tidb_config = TiDBVectorConfig(
            host=app_config.get('TIDB_VECTOR_HOST'),
            port=app_config.get('TIDB_VECTOR_PORT'),
            user=app_config.get('TIDB_VECTOR_USER'),
            password=app_config.get('TIDB_VECTOR_PASSWORD'),
            database=app_config.get('TIDB_VECTOR_DATABASE'),
        )
        return TiDBVector(collection_name=collection_name, config=tidb_config)

View File

@ -11,6 +11,10 @@ class BaseVector(ABC):
def __init__(self, collection_name: str): def __init__(self, collection_name: str):
self._collection_name = collection_name self._collection_name = collection_name
@abstractmethod
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod @abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError raise NotImplementedError

View File

@ -1,4 +1,4 @@
import json from abc import ABC, abstractmethod
from typing import Any from typing import Any
from flask import current_app from flask import current_app
@ -8,9 +8,23 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from models.dataset import Dataset
from models.dataset import Dataset, DatasetCollectionBinding
class AbstractVectorFactory(ABC):
    """Interface every vector-store factory must implement."""

    @abstractmethod
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
        """Build and return the backend-specific vector store for *dataset*."""
        raise NotImplementedError

    @staticmethod
    def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
        """Build the index-struct mapping persisted on a dataset record."""
        return {
            "type": vector_type,
            "vector_store": {"class_prefix": collection_name},
        }
class Vector: class Vector:
@ -32,188 +46,35 @@ class Vector:
if not vector_type: if not vector_type:
raise ValueError("Vector store must be specified.") raise ValueError("Vector store must be specified.")
if vector_type == "weaviate": vector_factory_cls = self.get_vector_factory(vector_type)
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
attributes=self._attributes
)
elif vector_type == "qdrant":
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
if self._dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Bindings is not exist!')
else:
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
if not self._dataset.index_struct_dict: @staticmethod
index_struct_dict = { def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
"type": 'qdrant', match vector_type:
"vector_store": {"class_prefix": collection_name} case VectorType.MILVUS:
} from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
self._dataset.index_struct = json.dumps(index_struct_dict) return MilvusVectorFactory
case VectorType.PGVECTOR:
return QdrantVector( from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
collection_name=collection_name, return PGVectorFactory
group_id=self._dataset.id, case VectorType.PGVECTO_RS:
config=QdrantConfig( from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
endpoint=config.get('QDRANT_URL'), return PGVectoRSFactory
api_key=config.get('QDRANT_API_KEY'), case VectorType.QDRANT:
root_path=current_app.root_path, from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
timeout=config.get('QDRANT_CLIENT_TIMEOUT'), return QdrantVectorFactory
grpc_port=config.get('QDRANT_GRPC_PORT'), case VectorType.RELYT:
prefer_grpc=config.get('QDRANT_GRPC_ENABLED') from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
) return RelytVectorFactory
) case VectorType.TIDB_VECTOR:
elif vector_type == "milvus": from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector return TiDBVectorFactory
if self._dataset.index_struct_dict: case VectorType.WEAVIATE:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
collection_name = class_prefix return WeaviateVectorFactory
else: case _:
dataset_id = self._dataset.id raise ValueError(f"Vector store {vector_type} is not supported.")
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
database=config.get('MILVUS_DATABASE'),
)
)
elif vector_type == "relyt":
from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'relyt',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=config.get('RELYT_HOST'),
port=config.get('RELYT_PORT'),
user=config.get('RELYT_USER'),
password=config.get('RELYT_PASSWORD'),
database=config.get('RELYT_DATABASE'),
),
group_id=self._dataset.id
)
elif vector_type == "pgvecto_rs":
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": 'pgvecto_rs',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
dim = len(self._embeddings.embed_query("pgvecto_rs"))
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=config.get('PGVECTO_RS_HOST'),
port=config.get('PGVECTO_RS_PORT'),
user=config.get('PGVECTO_RS_USER'),
password=config.get('PGVECTO_RS_PASSWORD'),
database=config.get('PGVECTO_RS_DATABASE'),
),
dim=dim
)
elif vector_type == "pgvector":
from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": "pgvector",
"vector_store": {"class_prefix": collection_name}}
self._dataset.index_struct = json.dumps(index_struct_dict)
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=config.get("PGVECTOR_HOST"),
port=config.get("PGVECTOR_PORT"),
user=config.get("PGVECTOR_USER"),
password=config.get("PGVECTOR_PASSWORD"),
database=config.get("PGVECTOR_DATABASE"),
),
)
elif vector_type == "tidb_vector":
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = self._dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict = {
"type": 'tidb_vector',
"vector_store": {"class_prefix": collection_name}
}
self._dataset.index_struct = json.dumps(index_struct_dict)
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=config.get('TIDB_VECTOR_HOST'),
port=config.get('TIDB_VECTOR_PORT'),
user=config.get('TIDB_VECTOR_USER'),
password=config.get('TIDB_VECTOR_PASSWORD'),
database=config.get('TIDB_VECTOR_DATABASE'),
),
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
    def create(self, texts: list = None, **kwargs):
        if texts:

View File

@ -0,0 +1,11 @@
from enum import Enum
class VectorType(str, Enum):
    """Identifiers for the supported vector database backends.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``VectorType.WEAVIATE == 'weaviate'``), allowing them to be used
    interchangeably with the raw strings stored in config (``VECTOR_STORE``)
    and in ``dataset.index_struct_dict['type']``.
    """
    MILVUS = 'milvus'
    PGVECTOR = 'pgvector'
    # NOTE(review): hyphenated here, but the legacy index_struct written by the
    # old factory code used the underscored form 'pgvecto_rs' — confirm that
    # migration/lookup handles both spellings.
    PGVECTO_RS = 'pgvecto-rs'
    QDRANT = 'qdrant'
    RELYT = 'relyt'
    TIDB_VECTOR = 'tidb_vector'
    WEAVIATE = 'weaviate'

View File

@ -1,12 +1,17 @@
import datetime
import json
from typing import Any, Optional

import requests
import weaviate
from flask import current_app
from pydantic import BaseModel, root_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
@ -59,7 +64,7 @@ class WeaviateVector(BaseVector):
        return client

    def get_type(self) -> str:
        return VectorType.WEAVIATE

    def get_collection_name(self, dataset: Dataset) -> str:
        if dataset.index_struct_dict:
@ -255,3 +260,25 @@ class WeaviateVector(BaseVector):
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        return value
class WeaviateVectorFactory(AbstractVectorFactory):
    """Factory that builds a :class:`WeaviateVector` bound to a dataset."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
        """Create a WeaviateVector for *dataset*.

        Reuses the collection name recorded in ``dataset.index_struct_dict``
        when one exists; otherwise derives a name from the dataset id and
        persists the generated index struct back onto the dataset.
        """
        struct = dataset.index_struct_dict
        if struct:
            # An index struct was written previously — reuse its class prefix.
            collection_name: str = struct['vector_store']['class_prefix']
        else:
            # First-time initialization: derive the name and record it.
            collection_name = Dataset.gen_collection_name_by_id(dataset.id)
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))

        app_config = current_app.config
        weaviate_config = WeaviateConfig(
            endpoint=app_config.get('WEAVIATE_ENDPOINT'),
            api_key=app_config.get('WEAVIATE_API_KEY'),
            # NOTE(review): int() raises TypeError if WEAVIATE_BATCH_SIZE is
            # unset — confirm a default is always configured.
            batch_size=int(app_config.get('WEAVIATE_BATCH_SIZE'))
        )
        return WeaviateVector(
            collection_name=collection_name,
            config=weaviate_config,
            attributes=attributes
        )