improve: generalize vector factory classes and vector type (#5033)
This commit is contained in: parent 3b62ab564a, commit bdad993901
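Summary: the change replaces the inline per-store if/elif construction in Vector._init_vector with one AbstractVectorFactory subclass per backend (Milvus, PGVector, PGVecto-RS, Qdrant, Relyt, TiDB Vector, Weaviate), adds a shared VectorType str-Enum in api/core/rag/datasource/vdb/vector_type.py, and swaps the hard-coded vector-store strings in the migration command and the dataset controllers for enum members.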
api/commands.py
@@ -9,6 +9,7 @@ from werkzeug.exceptions import NotFound
 from constants.languages import languages
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
@@ -266,15 +267,15 @@ def migrate_knowledge_vector_database():
                         skipped_count = skipped_count + 1
                         continue
                 collection_name = ''
-                if vector_type == "weaviate":
+                if vector_type == VectorType.WEAVIATE:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'weaviate',
+                        "type": VectorType.WEAVIATE,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "qdrant":
+                elif vector_type == VectorType.QDRANT:
                     if dataset.collection_binding_id:
                         dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                             filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
@@ -287,20 +288,20 @@ def migrate_knowledge_vector_database():
                         dataset_id = dataset.id
                         collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'qdrant',
+                        "type": VectorType.QDRANT,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
 
-                elif vector_type == "milvus":
+                elif vector_type == VectorType.MILVUS:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'milvus',
+                        "type": VectorType.MILVUS,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "relyt":
+                elif vector_type == VectorType.RELYT:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
@@ -308,16 +309,16 @@ def migrate_knowledge_vector_database():
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
-                elif vector_type == "pgvector":
+                elif vector_type == VectorType.PGVECTOR:
                     dataset_id = dataset.id
                     collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {
-                        "type": 'pgvector',
+                        "type": VectorType.PGVECTOR,
                         "vector_store": {"class_prefix": collection_name}
                     }
                     dataset.index_struct = json.dumps(index_struct_dict)
                 else:
-                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+                    raise ValueError(f"Vector store {vector_type} is not supported.")
 
                 vector = Vector(dataset)
                 click.echo(f"Start to migrate dataset {dataset.id}.")
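Note on compatibility: VectorType subclasses str, so json.dumps serializes VectorType.WEAVIATE as "weaviate". The index_struct payloads written by the migrated branches above are therefore byte-for-byte identical to the ones the old string literals produced.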
api/controllers/console/datasets/datasets.py
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
@@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource):
     @account_initialization_required
     def get(self):
         vector_type = current_app.config['VECTOR_STORE']
-        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search'
-                ]
-            }
-        elif vector_type in {"qdrant", "weaviate"}:
-            return {
-                'retrieval_method': [
-                    'semantic_search', 'full_text_search', 'hybrid_search'
-                ]
-            }
-        else:
-            raise ValueError("Unsupported vector db type.")
+        match vector_type:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.PGVECTO_RS | VectorType.TIDB_VECTOR:
+                return {
+                    'retrieval_method': [
+                        'semantic_search'
+                    ]
+                }
+            case VectorType.QDRANT | VectorType.WEAVIATE:
+                return {
+                    'retrieval_method': [
+                        'semantic_search', 'full_text_search', 'hybrid_search'
+                    ]
+                }
+            case _:
+                raise ValueError(f"Unsupported vector db type {vector_type}.")
 
 
 class DatasetRetrievalSettingMockApi(Resource):
@@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, vector_type):
-        if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search'
-                ]
-            }
-        elif vector_type in {'qdrant', 'weaviate'}:
-            return {
-                'retrieval_method': [
-                    'semantic_search', 'full_text_search', 'hybrid_search'
-                ]
-            }
-        else:
-            raise ValueError("Unsupported vector db type.")
+        match vector_type:
+            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR:
+                return {
+                    'retrieval_method': [
+                        'semantic_search'
+                    ]
+                }
+            case VectorType.QDRANT | VectorType.WEAVIATE:
+                return {
+                    'retrieval_method': [
+                        'semantic_search', 'full_text_search', 'hybrid_search'
+                    ]
+                }
+            case _:
+                raise ValueError(f"Unsupported vector db type {vector_type}.")
 
 
 class DatasetErrorDocs(Resource):
     @setup_required
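Both rewrites above rely on how structural pattern matching (Python >= 3.10) treats dotted names: case VectorType.MILVUS is a value pattern compared with ==, and because VectorType subclasses str, a plain 'milvus' coming from VECTOR_STORE config or from the URL path still matches. A standalone sketch (the two-member enum stands in for the real one):

    from enum import Enum

    class VectorType(str, Enum):   # stand-in for api/core/rag/datasource/vdb/vector_type.py
        MILVUS = 'milvus'
        QDRANT = 'qdrant'

    vector_type = 'milvus'         # e.g. current_app.config['VECTOR_STORE']
    match vector_type:
        case VectorType.MILVUS:    # value pattern: '==' against the str-subclass member, so True
            print('semantic_search')
        case VectorType.QDRANT:
            print('semantic_search, full_text_search, hybrid_search')
        case _:
            raise ValueError(f"Unsupported vector db type {vector_type}.")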
api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -1,14 +1,20 @@
+import json
 import logging
 from typing import Any, Optional
 from uuid import uuid4
 
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from pymilvus import MilvusClient, MilvusException, connections
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +61,7 @@ class MilvusVector(BaseVector):
         self._fields = []
 
     def get_type(self) -> str:
-        return 'milvus'
+        return VectorType.MILVUS
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         index_params = {
@@ -254,10 +260,36 @@ class MilvusVector(BaseVector):
                 schema=schema, index_param=index_params,
                 consistency_level=self._consistency_level)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
     def _init_client(self, config) -> MilvusClient:
         if config.secure:
             uri = "https://" + str(config.host) + ":" + str(config.port)
         else:
             uri = "http://" + str(config.host) + ":" + str(config.port)
-        client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
+        client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database)
         return client
+
+
+class MilvusVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
+
+        config = current_app.config
+        return MilvusVector(
+            collection_name=collection_name,
+            config=MilvusConfig(
+                host=config.get('MILVUS_HOST'),
+                port=config.get('MILVUS_PORT'),
+                user=config.get('MILVUS_USER'),
+                password=config.get('MILVUS_PASSWORD'),
+                secure=config.get('MILVUS_SECURE'),
+                database=config.get('MILVUS_DATABASE'),
+            )
+        )
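The same init_vector shape repeats in every factory below: reuse the class_prefix persisted in index_struct when the dataset already has one, otherwise derive a collection name from the dataset id and persist it via gen_index_struct_dict, then build the store-specific vector with settings read from Flask config. (The source of this hunk passed VectorType.WEAVIATE to gen_index_struct_dict inside MilvusVectorFactory, an apparent copy-paste slip corrected to VectorType.MILVUS above; the same applies to PGVectoRSFactory further down.)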
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
@@ -1,7 +1,9 @@
+import json
 import logging
 from typing import Any
 from uuid import UUID, uuid4
 
+from flask import current_app
 from numpy import ndarray
 from pgvecto_rs.sqlalchemy import Vector
 from pydantic import BaseModel, root_validator
@@ -10,10 +12,14 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -67,7 +73,7 @@ class PGVectoRS(BaseVector):
         self._distance_op = "<=>"
 
     def get_type(self) -> str:
-        return 'pgvecto-rs'
+        return VectorType.PGVECTO_RS
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         self.create_collection(len(embeddings[0]))
@@ -222,3 +228,28 @@ class PGVectoRS(BaseVector):
         #     docs.append(doc)
         # return docs
         return []
+
+
+class PGVectoRSFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name))
+        dim = len(embeddings.embed_query("pgvecto_rs"))
+        config = current_app.config
+        return PGVectoRS(
+            collection_name=collection_name,
+            config=PgvectoRSConfig(
+                host=config.get('PGVECTO_RS_HOST'),
+                port=config.get('PGVECTO_RS_PORT'),
+                user=config.get('PGVECTO_RS_USER'),
+                password=config.get('PGVECTO_RS_PASSWORD'),
+                database=config.get('PGVECTO_RS_DATABASE'),
+            ),
+            dim=dim
+        )
api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -5,11 +5,16 @@ from typing import Any
 
 import psycopg2.extras
 import psycopg2.pool
+from flask import current_app
 from pydantic import BaseModel, root_validator
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 
 class PGVectorConfig(BaseModel):
@@ -51,7 +56,7 @@ class PGVector(BaseVector):
         self.table_name = f"embedding_{collection_name}"
 
     def get_type(self) -> str:
-        return "pgvector"
+        return VectorType.PGVECTOR
 
     def _create_connection_pool(self, config: PGVectorConfig):
         return psycopg2.pool.SimpleConnectionPool(
@@ -167,3 +172,27 @@ class PGVector(BaseVector):
             cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
             # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class PGVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
+
+        config = current_app.config
+        return PGVector(
+            collection_name=collection_name,
+            config=PGVectorConfig(
+                host=config.get("PGVECTOR_HOST"),
+                port=config.get("PGVECTOR_PORT"),
+                user=config.get("PGVECTOR_USER"),
+                password=config.get("PGVECTOR_PASSWORD"),
+                database=config.get("PGVECTOR_DATABASE"),
+            ),
+        )
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -1,3 +1,4 @@
+import json
 import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
@@ -5,6 +6,7 @@ from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 import qdrant_client
+from flask import current_app
 from pydantic import BaseModel
 from qdrant_client.http import models as rest
 from qdrant_client.http.models import (
@@ -17,10 +19,15 @@ from qdrant_client.http.models import (
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset, DatasetCollectionBinding
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -69,7 +76,7 @@ class QdrantVector(BaseVector):
         self._group_id = group_id
 
     def get_type(self) -> str:
-        return 'qdrant'
+        return VectorType.QDRANT
 
     def to_index_struct(self) -> dict:
         return {
@@ -408,3 +415,40 @@ class QdrantVector(BaseVector):
                     page_content=scored_point.payload.get(content_payload_key),
                     metadata=scored_point.payload.get(metadata_payload_key) or {},
                 )
+
+
+class QdrantVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
+        if dataset.collection_binding_id:
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                one_or_none()
+            if dataset_collection_binding:
+                collection_name = dataset_collection_binding.collection_name
+            else:
+                raise ValueError('Dataset Collection Bindings do not exist!')
+        else:
+            if dataset.index_struct_dict:
+                class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+                collection_name = class_prefix
+            else:
+                dataset_id = dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+
+        if not dataset.index_struct_dict:
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.QDRANT, collection_name))
+
+        config = current_app.config
+        return QdrantVector(
+            collection_name=collection_name,
+            group_id=dataset.id,
+            config=QdrantConfig(
+                endpoint=config.get('QDRANT_URL'),
+                api_key=config.get('QDRANT_API_KEY'),
+                root_path=config.root_path,
+                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
+                grpc_port=config.get('QDRANT_GRPC_PORT'),
+                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+            )
+        )
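Note: as in the inline branch this replaces, QdrantVectorFactory gives a shared DatasetCollectionBinding collection precedence over the per-dataset name, and only writes index_struct when the dataset does not already carry one.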
api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -1,12 +1,19 @@
+import json
 import uuid
 from typing import Any, Optional
 
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.dialects.postgresql import JSON, TEXT
 from sqlalchemy.orm import Session
 
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from models.dataset import Dataset
+
 try:
     from sqlalchemy.orm import declarative_base
 except ImportError:
@@ -53,7 +60,7 @@ class RelytVector(BaseVector):
         self._group_id = group_id
 
     def get_type(self) -> str:
-        return 'relyt'
+        return VectorType.RELYT
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         index_params = {}
@@ -240,10 +247,10 @@ class RelytVector(BaseVector):
         return docs
 
     def similarity_search_with_score_by_vector(
             self,
             embedding: list[float],
             k: int = 4,
             filter: Optional[dict] = None,
     ) -> list[tuple[Document, float]]:
         # Add the filter if provided
         try:
@@ -298,3 +305,28 @@ class RelytVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz/relyt doesn't support bm25 search
         return []
+
+
+class RelytVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.RELYT, collection_name))
+
+        config = current_app.config
+        return RelytVector(
+            collection_name=collection_name,
+            config=RelytConfig(
+                host=config.get('RELYT_HOST'),
+                port=config.get('RELYT_PORT'),
+                user=config.get('RELYT_USER'),
+                password=config.get('RELYT_PASSWORD'),
+                database=config.get('RELYT_DATABASE'),
+            ),
+            group_id=dataset.id
+        )
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -3,14 +3,19 @@ import logging
 from typing import Any
 
 import sqlalchemy
+from flask import current_app
 from pydantic import BaseModel, root_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
+from models.dataset import Dataset
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +44,9 @@ class TiDBVectorConfig(BaseModel):
 
 class TiDBVector(BaseVector):
 
+    def get_type(self) -> str:
+        return VectorType.TIDB_VECTOR
+
     def _table(self, dim: int) -> Table:
         from tidb_vector.sqlalchemy import VectorType
         return Table(
@@ -214,3 +222,28 @@ class TiDBVector(BaseVector):
         with Session(self._engine) as session:
             session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
             session.commit()
+
+
+class TiDBVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
+
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
+
+        config = current_app.config
+        return TiDBVector(
+            collection_name=collection_name,
+            config=TiDBVectorConfig(
+                host=config.get('TIDB_VECTOR_HOST'),
+                port=config.get('TIDB_VECTOR_PORT'),
+                user=config.get('TIDB_VECTOR_USER'),
+                password=config.get('TIDB_VECTOR_PASSWORD'),
+                database=config.get('TIDB_VECTOR_DATABASE'),
+            ),
+        )
api/core/rag/datasource/vdb/vector_base.py
@@ -11,6 +11,10 @@ class BaseVector(ABC):
     def __init__(self, collection_name: str):
         self._collection_name = collection_name
 
+    @abstractmethod
+    def get_type(self) -> str:
+        raise NotImplementedError
+
     @abstractmethod
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
         raise NotImplementedError
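Promoting get_type to an abstract method makes the backend identity part of the BaseVector contract: each concrete store now answers with its VectorType member instead of callers comparing hard-coded strings.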
api/core/rag/datasource/vdb/vector_factory.py
@@ -1,4 +1,4 @@
-import json
+from abc import ABC, abstractmethod
 from typing import Any
 
 from flask import current_app
@@ -8,9 +8,23 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
+from models.dataset import Dataset
+
+
+class AbstractVectorFactory(ABC):
+    @abstractmethod
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
+        raise NotImplementedError
+
+    @staticmethod
+    def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict:
+        index_struct_dict = {
+            "type": vector_type,
+            "vector_store": {"class_prefix": collection_name}
+        }
+        return index_struct_dict
 
 
 class Vector:
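gen_index_struct_dict centralizes the index_struct literal that was previously copy-pasted into every branch. A small sketch of what it returns (run inside the api project; the collection name is illustrative):

    import json

    from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
    from core.rag.datasource.vdb.vector_type import VectorType

    struct = AbstractVectorFactory.gen_index_struct_dict(
        VectorType.QDRANT, 'Vector_index_example_Node')  # hypothetical collection name
    print(json.dumps(struct))
    # -> {"type": "qdrant", "vector_store": {"class_prefix": "Vector_index_example_Node"}}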
@@ -32,188 +46,35 @@ class Vector:
         if not vector_type:
             raise ValueError("Vector store must be specified.")
 
-        if vector_type == "weaviate":
-            from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'weaviate',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return WeaviateVector(
-                collection_name=collection_name,
-                config=WeaviateConfig(
-                    endpoint=config.get('WEAVIATE_ENDPOINT'),
-                    api_key=config.get('WEAVIATE_API_KEY'),
-                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
-                ),
-                attributes=self._attributes
-            )
-        elif vector_type == "qdrant":
-            from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
-            if self._dataset.collection_binding_id:
-                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                    filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
-                    one_or_none()
-                if dataset_collection_binding:
-                    collection_name = dataset_collection_binding.collection_name
-                else:
-                    raise ValueError('Dataset Collection Bindings is not exist!')
-            else:
-                if self._dataset.index_struct_dict:
-                    class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                    collection_name = class_prefix
-                else:
-                    dataset_id = self._dataset.id
-                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-
-            if not self._dataset.index_struct_dict:
-                index_struct_dict = {
-                    "type": 'qdrant',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-
-            return QdrantVector(
-                collection_name=collection_name,
-                group_id=self._dataset.id,
-                config=QdrantConfig(
-                    endpoint=config.get('QDRANT_URL'),
-                    api_key=config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path,
-                    timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                    grpc_port=config.get('QDRANT_GRPC_PORT'),
-                    prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
-                )
-            )
-        elif vector_type == "milvus":
-            from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'milvus',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return MilvusVector(
-                collection_name=collection_name,
-                config=MilvusConfig(
-                    host=config.get('MILVUS_HOST'),
-                    port=config.get('MILVUS_PORT'),
-                    user=config.get('MILVUS_USER'),
-                    password=config.get('MILVUS_PASSWORD'),
-                    secure=config.get('MILVUS_SECURE'),
-                    database=config.get('MILVUS_DATABASE'),
-                )
-            )
-        elif vector_type == "relyt":
-            from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": 'relyt',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return RelytVector(
-                collection_name=collection_name,
-                config=RelytConfig(
-                    host=config.get('RELYT_HOST'),
-                    port=config.get('RELYT_PORT'),
-                    user=config.get('RELYT_USER'),
-                    password=config.get('RELYT_PASSWORD'),
-                    database=config.get('RELYT_DATABASE'),
-                ),
-                group_id=self._dataset.id
-            )
-        elif vector_type == "pgvecto_rs":
-            from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix.lower()
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-                index_struct_dict = {
-                    "type": 'pgvecto_rs',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            dim = len(self._embeddings.embed_query("pgvecto_rs"))
-            return PGVectoRS(
-                collection_name=collection_name,
-                config=PgvectoRSConfig(
-                    host=config.get('PGVECTO_RS_HOST'),
-                    port=config.get('PGVECTO_RS_PORT'),
-                    user=config.get('PGVECTO_RS_USER'),
-                    password=config.get('PGVECTO_RS_PASSWORD'),
-                    database=config.get('PGVECTO_RS_DATABASE'),
-                ),
-                dim=dim
-            )
-        elif vector_type == "pgvector":
-            from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig
-
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"]
-                collection_name = class_prefix
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-                index_struct_dict = {
-                    "type": "pgvector",
-                    "vector_store": {"class_prefix": collection_name}}
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-            return PGVector(
-                collection_name=collection_name,
-                config=PGVectorConfig(
-                    host=config.get("PGVECTOR_HOST"),
-                    port=config.get("PGVECTOR_PORT"),
-                    user=config.get("PGVECTOR_USER"),
-                    password=config.get("PGVECTOR_PASSWORD"),
-                    database=config.get("PGVECTOR_DATABASE"),
-                ),
-            )
-        elif vector_type == "tidb_vector":
-            from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
-
-            if self._dataset.index_struct_dict:
-                class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
-                collection_name = class_prefix.lower()
-            else:
-                dataset_id = self._dataset.id
-                collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-                index_struct_dict = {
-                    "type": 'tidb_vector',
-                    "vector_store": {"class_prefix": collection_name}
-                }
-                self._dataset.index_struct = json.dumps(index_struct_dict)
-
-            return TiDBVector(
-                collection_name=collection_name,
-                config=TiDBVectorConfig(
-                    host=config.get('TIDB_VECTOR_HOST'),
-                    port=config.get('TIDB_VECTOR_PORT'),
-                    user=config.get('TIDB_VECTOR_USER'),
-                    password=config.get('TIDB_VECTOR_PASSWORD'),
-                    database=config.get('TIDB_VECTOR_DATABASE'),
-                ),
-            )
-        else:
-            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+        vector_factory_cls = self.get_vector_factory(vector_type)
+        return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings)
+
+    @staticmethod
+    def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
+        match vector_type:
+            case VectorType.MILVUS:
+                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
+                return MilvusVectorFactory
+            case VectorType.PGVECTOR:
+                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
+                return PGVectorFactory
+            case VectorType.PGVECTO_RS:
+                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
+                return PGVectoRSFactory
+            case VectorType.QDRANT:
+                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
+                return QdrantVectorFactory
+            case VectorType.RELYT:
+                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
+                return RelytVectorFactory
+            case VectorType.TIDB_VECTOR:
+                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
+                return TiDBVectorFactory
+            case VectorType.WEAVIATE:
+                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
+                return WeaviateVectorFactory
+            case _:
+                raise ValueError(f"Vector store {vector_type} is not supported.")
 
     def create(self, texts: list = None, **kwargs):
         if texts:

api/core/rag/datasource/vdb/vector_type.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class VectorType(str, Enum):
+    MILVUS = 'milvus'
+    PGVECTOR = 'pgvector'
+    PGVECTO_RS = 'pgvecto-rs'
+    QDRANT = 'qdrant'
+    RELYT = 'relyt'
+    TIDB_VECTOR = 'tidb_vector'
+    WEAVIATE = 'weaviate'
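Taken together, the dispatch path is now: Vector resolves a factory class from the configured store, and the factory builds the concrete BaseVector. The per-case imports stay lazy, so only the configured backend's client library is loaded. A sketch of the call flow (assumes a Flask app context, an existing models.dataset.Dataset row named dataset, and a placeholder documents list):

    from core.rag.datasource.vdb.vector_factory import Vector
    from core.rag.datasource.vdb.vector_type import VectorType

    # Resolve a factory class directly, exactly as _init_vector does internally.
    factory_cls = Vector.get_vector_factory(VectorType.QDRANT)  # -> QdrantVectorFactory

    # Or let Vector drive the whole flow for a dataset.
    vector = Vector(dataset)        # picks the factory from current_app.config['VECTOR_STORE']
    vector.create(texts=documents)  # delegates to the store-specific BaseVector implementation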
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -1,12 +1,17 @@
 import datetime
+import json
 from typing import Any, Optional
 
 import requests
 import weaviate
+from flask import current_app
 from pydantic import BaseModel, root_validator
 
+from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
@@ -59,7 +64,7 @@ class WeaviateVector(BaseVector):
         return client
 
     def get_type(self) -> str:
-        return 'weaviate'
+        return VectorType.WEAVIATE
 
     def get_collection_name(self, dataset: Dataset) -> str:
         if dataset.index_struct_dict:
@@ -255,3 +260,25 @@ class WeaviateVector(BaseVector):
         if isinstance(value, datetime.datetime):
             return value.isoformat()
         return value
+
+
+class WeaviateVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            collection_name = class_prefix
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            dataset.index_struct = json.dumps(
+                self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
+
+        return WeaviateVector(
+            collection_name=collection_name,
+            config=WeaviateConfig(
+                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
+                api_key=current_app.config.get('WEAVIATE_API_KEY'),
+                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+            ),
+            attributes=attributes
+        )