From 3e9dbe3e0a9c1d0d2fce347e1e79557d2eb23c8c Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 29 Apr 2024 11:58:17 +0800 Subject: [PATCH] add pgvecto_rs support and upgrade SQLAlchemy (#3833) --- .github/workflows/api-tests.yml | 4 +- api/.env.example | 9 +- api/config.py | 7 + api/controllers/console/datasets/datasets.py | 2 +- .../rag/datasource/vdb/pgvecto_rs/__init__.py | 0 .../datasource/vdb/pgvecto_rs/collection.py | 12 + .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 224 ++++++++++++++++++ .../rag/datasource/vdb/relyt/relyt_vector.py | 2 +- api/core/rag/datasource/vdb/vector_factory.py | 25 ++ api/migrations/script.py.mako | 1 + api/models/__init__.py | 27 +++ api/models/account.py | 22 +- api/models/api_based_extension.py | 7 +- api/models/dataset.py | 73 +++--- api/models/model.py | 178 +++++++------- api/models/provider.py | 25 +- api/models/source.py | 7 +- api/models/tool.py | 7 +- api/models/tools.py | 42 ++-- api/models/web.py | 18 +- api/models/workflow.py | 47 ++-- api/requirements.txt | 2 +- .../vdb/pgvecto_rs/__init__.py | 0 .../vdb/pgvecto_rs/test_pgvecto_rs.py | 37 +++ .../vdb/test_vector_store.py | 2 +- docker/docker-compose.pgvecto-rs.yaml | 24 ++ 26 files changed, 584 insertions(+), 220 deletions(-) create mode 100644 api/core/rag/datasource/vdb/pgvecto_rs/__init__.py create mode 100644 api/core/rag/datasource/vdb/pgvecto_rs/collection.py create mode 100644 api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py create mode 100644 api/tests/integration_tests/vdb/pgvecto_rs/__init__.py create mode 100644 api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py create mode 100644 docker/docker-compose.pgvecto-rs.yaml diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6bdc0e726d..60624eb570 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -61,19 +61,21 @@ jobs: - name: Run Workflow run: dev/pytest/pytest_workflow.sh - - name: Set up Vector Stores (Weaviate, Qdrant and Milvus) + - name: Set up Vector Stores (Weaviate, Qdrant, Milvus, PgVecto-RS) uses: hoverkraft-tech/compose-action@v2.0.0 with: compose-file: | docker/docker-compose.middleware.yaml docker/docker-compose.qdrant.yaml docker/docker-compose.milvus.yaml + docker/docker-compose.pgvecto-rs.yaml services: | weaviate qdrant etcd minio milvus-standalone + pgvecto-rs - name: Test Vector Stores run: dev/pytest/pytest_vdb.sh diff --git a/api/.env.example b/api/.env.example index 1e348d77ba..d60e9947dd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -62,7 +62,7 @@ GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-stri WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, relyt +# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs VECTOR_STORE=weaviate # Weaviate configuration @@ -92,6 +92,13 @@ RELYT_USER=postgres RELYT_PASSWORD=postgres RELYT_DATABASE=postgres +# PGVECTO_RS configuration +PGVECTO_RS_HOST=localhost +PGVECTO_RS_PORT=5431 +PGVECTO_RS_USER=postgres +PGVECTO_RS_PASSWORD=difyai123456 +PGVECTO_RS_DATABASE=postgres + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/config.py b/api/config.py index 12a57e9371..de4d043b26 100644 --- a/api/config.py +++ b/api/config.py @@ -251,6 +251,13 @@ class Config: self.RELYT_PASSWORD = get_env('RELYT_PASSWORD') self.RELYT_DATABASE = get_env('RELYT_DATABASE') + # pgvecto rs settings + self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST') + self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT') + self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER') + self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD') + self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE') + # ------------------------ # Mail Configurations. # ------------------------ diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f98ab419c5..40ded54120 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -476,7 +476,7 @@ class DatasetRetrievalSettingApi(Resource): @account_initialization_required def get(self): vector_type = current_app.config['VECTOR_STORE'] - if vector_type == 'milvus' or vector_type == 'relyt': + if vector_type == 'milvus' or vector_type == 'pgvecto_rs' or vector_type == 'relyt': return { 'retrieval_method': [ 'semantic_search' diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/__init__.py b/api/core/rag/datasource/vdb/pgvecto_rs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/collection.py b/api/core/rag/datasource/vdb/pgvecto_rs/collection.py new file mode 100644 index 0000000000..c335bc610d --- /dev/null +++ b/api/core/rag/datasource/vdb/pgvecto_rs/collection.py @@ -0,0 +1,12 @@ +from uuid import UUID + +from numpy import ndarray +from sqlalchemy.orm import DeclarativeBase, Mapped + + +class CollectionORM(DeclarativeBase): + __tablename__: str + id: Mapped[UUID] + text: Mapped[str] + meta: Mapped[dict] + vector: Mapped[ndarray] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py new file mode 100644 index 0000000000..5735b79b6e --- /dev/null +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -0,0 +1,224 @@ +import logging +from typing import Any +from uuid import UUID, uuid4 + +from numpy import ndarray +from pgvecto_rs.sqlalchemy import Vector +from pydantic import BaseModel, root_validator +from sqlalchemy import Float, String, create_engine, insert, select, text +from sqlalchemy import text as sql_text +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Mapped, Session, mapped_column + +from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class PgvectoRSConfig(BaseModel): + host: str + port: int + user: str + password: str + database: str + + @root_validator() + def validate_config(cls, values: dict) -> dict: + if not values['host']: + raise ValueError("config PGVECTO_RS_HOST is required") + if not values['port']: + raise ValueError("config PGVECTO_RS_PORT is required") + if not values['user']: + raise ValueError("config PGVECTO_RS_USER is required") + if not values['password']: + raise ValueError("config PGVECTO_RS_PASSWORD is required") + if not values['database']: + raise ValueError("config PGVECTO_RS_DATABASE is required") + return values + + +class PGVectoRS(BaseVector): + + def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): + super().__init__(collection_name) + self._client_config = config + self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._client = create_engine(self._url) + with Session(self._client) as session: + session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) + session.commit() + self._fields = [] + + class _Table(CollectionORM): + __tablename__ = collection_name + __table_args__ = {"extend_existing": True} # noqa: RUF012 + id: Mapped[UUID] = mapped_column( + postgresql.UUID(as_uuid=True), + primary_key=True, + ) + text: Mapped[str] = mapped_column(String) + meta: Mapped[dict] = mapped_column(postgresql.JSONB) + vector: Mapped[ndarray] = mapped_column(Vector(dim)) + + self._table = _Table + self._distance_op = "<=>" + + def get_type(self) -> str: + return 'pgvecto-rs' + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self.create_collection(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def create_collection(self, dimension: int): + lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + index_name = f"{self._collection_name}_embedding_index" + with Session(self._client) as session: + create_statement = sql_text(f""" + CREATE TABLE IF NOT EXISTS {self._collection_name} ( + id UUID PRIMARY KEY, + text TEXT NOT NULL, + meta JSONB NOT NULL, + vector vector({dimension}) NOT NULL + ) using heap; + """) + session.execute(create_statement) + index_statement = sql_text(f""" + CREATE INDEX IF NOT EXISTS {index_name} + ON {self._collection_name} USING vectors(vector vector_l2_ops) + WITH (options = $$ + optimizing.optimizing_threads = 30 + segment.max_growing_segment_size = 2000 + segment.max_sealed_segment_size = 30000000 + [indexing.hnsw] + m=30 + ef_construction=500 + $$); + """) + session.execute(index_statement) + session.commit() + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + pks = [] + with Session(self._client) as session: + for document, embedding in zip(documents, embeddings): + pk = uuid4() + session.execute( + insert(self._table).values( + id=pk, + text=document.page_content, + meta=document.metadata, + vector=embedding, + ), + ) + pks.append(pk) + session.commit() + + return pks + + def delete_by_document_id(self, document_id: str): + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + with Session(self._client) as session: + select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") + session.execute(select_statement, {'ids': ids}) + session.commit() + + def get_ids_by_metadata_field(self, key: str, value: str): + result = None + with Session(self._client) as session: + select_statement = sql_text( + f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " + ) + result = session.execute(select_statement).fetchall() + if result: + return [item[0] for item in result] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + + ids = self.get_ids_by_metadata_field(key, value) + if ids: + with Session(self._client) as session: + select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") + session.execute(select_statement, {'ids': ids}) + session.commit() + + def delete_by_ids(self, ids: list[str]) -> None: + with Session(self._client) as session: + select_statement = sql_text( + f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " + ) + result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + if result: + ids = [item[0] for item in result] + if ids: + with Session(self._client) as session: + select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") + session.execute(select_statement, {'ids': ids}) + session.commit() + + def delete(self) -> None: + with Session(self._client) as session: + session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) + session.commit() + + def text_exists(self, id: str) -> bool: + with Session(self._client) as session: + select_statement = sql_text( + f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; " + ) + result = session.execute(select_statement).fetchall() + return len(result) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + with Session(self._client) as session: + stmt = ( + select( + self._table, + self._table.vector.op(self._distance_op, return_type=Float)( + query_vector, + ).label("distance"), + ) + .limit(kwargs.get('top_k', 2)) + .order_by("distance") + ) + res = session.execute(stmt) + results = [(row[0], row[1]) for row in res] + + # Organize results. + docs = [] + for record, dis in results: + metadata = record.meta + score = 1 - dis + metadata['score'] = score + score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + if score > score_threshold: + doc = Document(page_content=record.text, + metadata=metadata) + docs.append(doc) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # with Session(self._client) as session: + # select_statement = sql_text( + # f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery" + # ) + # results = session.execute(select_statement).fetchall() + # if results: + # docs = [] + # for result in results: + # doc = Document(page_content=result[0], + # metadata=result[1]) + # docs.append(doc) + # return docs + return [] diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index c7d3575352..74b91db27e 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -235,7 +235,7 @@ class RelytVector(BaseVector): docs = [] for document, score in results: score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if score > score_threshold: + if 1 - score > score_threshold: docs.append(document) return docs diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 59f86b5b5f..2405d16b1d 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -139,6 +139,31 @@ class Vector: ), group_id=self._dataset.id ) + elif vector_type == "pgvecto_rs": + from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig + if self._dataset.index_struct_dict: + class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = self._dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + index_struct_dict = { + "type": 'pgvecto_rs', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) + dim = len(self._embeddings.embed_query("pgvecto_rs")) + return PGVectoRS( + collection_name=collection_name, + config=PgvectoRSConfig( + host=config.get('PGVECTO_RS_HOST'), + port=config.get('PGVECTO_RS_PORT'), + user=config.get('PGVECTO_RS_USER'), + password=config.get('PGVECTO_RS_PASSWORD'), + database=config.get('PGVECTO_RS_DATABASE'), + ), + dim=dim + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/migrations/script.py.mako b/api/migrations/script.py.mako index 2c0156303a..728ccc6a9a 100644 --- a/api/migrations/script.py.mako +++ b/api/migrations/script.py.mako @@ -6,6 +6,7 @@ Create Date: ${create_date} """ from alembic import op +import models as models import sqlalchemy as sa ${imports if imports else ""} diff --git a/api/models/__init__.py b/api/models/__init__.py index 47eec53542..3b832cd22d 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,5 +1,8 @@ from enum import Enum +from sqlalchemy import CHAR, TypeDecorator +from sqlalchemy.dialects.postgresql import UUID + class CreatedByRole(Enum): """ @@ -42,3 +45,27 @@ class CreatedFrom(Enum): if role.value == value: return role raise ValueError(f'invalid createdFrom value {value}') + + +class StringUUID(TypeDecorator): + impl = CHAR + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == 'postgresql': + return str(value) + else: + return value.hex + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_result_value(self, value, dialect): + if value is None: + return value + return str(value) diff --git a/api/models/account.py b/api/models/account.py index d8e587c90c..3d5e955732 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,9 +2,9 @@ import enum import json from flask_login import UserMixin -from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db +from models import StringUUID class AccountStatus(str, enum.Enum): @@ -22,7 +22,7 @@ class Account(UserMixin, db.Model): db.Index('account_email_idx', 'email') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -128,7 +128,7 @@ class Tenant(db.Model): db.PrimaryKeyConstraint('id', name='tenant_pkey'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) @@ -168,12 +168,12 @@ class TenantAccountJoin(db.Model): db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - account_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) role = db.Column(db.String(16), nullable=False, server_default='normal') - invited_by = db.Column(UUID, nullable=True) + invited_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -186,8 +186,8 @@ class AccountIntegrate(db.Model): db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - account_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + account_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) @@ -208,7 +208,7 @@ class InvitationCode(db.Model): code = db.Column(db.String(32), nullable=False) status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying")) used_at = db.Column(db.DateTime) - used_by_tenant_id = db.Column(UUID) - used_by_account_id = db.Column(UUID) + used_by_tenant_id = db.Column(StringUUID) + used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index e34cfb8f7b..d1f9cd78a7 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,8 +1,7 @@ import enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from models import StringUUID class APIBasedExtensionPoint(enum.Enum): @@ -19,8 +18,8 @@ class APIBasedExtension(db.Model): db.Index('api_based_extension_tenant_idx', 'tenant_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) diff --git a/api/models/dataset.py b/api/models/dataset.py index 0e85008615..01b068fa2a 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -4,10 +4,11 @@ import pickle from json import JSONDecodeError from sqlalchemy import func -from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db from extensions.ext_storage import storage +from models import StringUUID from models.account import Account from models.model import App, Tag, TagBinding, UploadFile @@ -22,8 +23,8 @@ class Dataset(db.Model): INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) provider = db.Column(db.String(255), nullable=False, @@ -33,15 +34,15 @@ class Dataset(db.Model): data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(UUID, nullable=True) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) - collection_binding_id = db.Column(UUID, nullable=True) + collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) @property @@ -145,13 +146,13 @@ class DatasetProcessRule(db.Model): db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), ) - id = db.Column(UUID, nullable=False, + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - dataset_id = db.Column(UUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -197,19 +198,19 @@ class Document(db.Model): ) # initial fields - id = db.Column(UUID, nullable=False, + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - dataset_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) data_source_type = db.Column(db.String(255), nullable=False) data_source_info = db.Column(db.Text, nullable=True) - dataset_process_rule_id = db.Column(UUID, nullable=True) + dataset_process_rule_id = db.Column(StringUUID, nullable=True) batch = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) created_from = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) - created_api_request_id = db.Column(UUID, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_api_request_id = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -234,7 +235,7 @@ class Document(db.Model): # pause is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) - paused_by = db.Column(UUID, nullable=True) + paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) # error @@ -247,11 +248,11 @@ class Document(db.Model): enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(UUID, nullable=True) + disabled_by = db.Column(StringUUID, nullable=True) archived = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) archived_reason = db.Column(db.String(255), nullable=True) - archived_by = db.Column(UUID, nullable=True) + archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -356,11 +357,11 @@ class DocumentSegment(db.Model): ) # initial fields - id = db.Column(UUID, nullable=False, + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - dataset_id = db.Column(UUID, nullable=False) - document_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) content = db.Column(db.Text, nullable=False) answer = db.Column(db.Text, nullable=True) @@ -377,13 +378,13 @@ class DocumentSegment(db.Model): enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) disabled_at = db.Column(db.DateTime, nullable=True) - disabled_by = db.Column(UUID, nullable=True) + disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(UUID, nullable=True) + updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) indexing_at = db.Column(db.DateTime, nullable=True) @@ -421,9 +422,9 @@ class AppDatasetJoin(db.Model): db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), ) - id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - dataset_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property @@ -438,13 +439,13 @@ class DatasetQuery(db.Model): db.Index('dataset_query_dataset_id_idx', 'dataset_id'), ) - id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) - dataset_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) - source_app_id = db.Column(UUID, nullable=True) + source_app_id = db.Column(StringUUID, nullable=True) created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -455,8 +456,8 @@ class DatasetKeywordTable(db.Model): db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), ) - id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - dataset_id = db.Column(UUID, nullable=False, unique=True) + id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) data_source_type = db.Column(db.String(255), nullable=False, server_default=db.text("'database'::character varying")) @@ -501,7 +502,7 @@ class Embedding(db.Model): db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx') ) - id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) model_name = db.Column(db.String(40), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")) hash = db.Column(db.String(64), nullable=False) @@ -525,7 +526,7 @@ class DatasetCollectionBinding(db.Model): ) - id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) diff --git a/api/models/model.py b/api/models/model.py index 9d5a492277..59b88eb3b1 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,13 +7,13 @@ from typing import Optional from flask import current_app, request from flask_login import UserMixin from sqlalchemy import Float, text -from sqlalchemy.dialects.postgresql import UUID from core.file.tool_file_parser import ToolFileParser from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string +from . import StringUUID from .account import Account, Tenant @@ -56,15 +56,15 @@ class App(db.Model): db.Index('app_tenant_id_idx', 'tenant_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) - app_model_config_id = db.Column(UUID, nullable=True) - workflow_id = db.Column(UUID, nullable=True) + app_model_config_id = db.Column(StringUUID, nullable=True) + workflow_id = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) @@ -207,8 +207,8 @@ class AppModelConfig(db.Model): db.Index('app_app_id_idx', 'app_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) @@ -430,8 +430,8 @@ class RecommendedApp(db.Model): db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') ) - id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) @@ -458,10 +458,10 @@ class InstalledApp(db.Model): db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=False) - app_owner_tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + app_owner_tenant_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) last_used_at = db.Column(db.DateTime, nullable=True) @@ -486,9 +486,9 @@ class Conversation(db.Model): db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - app_model_config_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) @@ -502,10 +502,10 @@ class Conversation(db.Model): status = db.Column(db.String(255), nullable=False) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(UUID) - from_account_id = db.Column(UUID) + from_end_user_id = db.Column(StringUUID) + from_account_id = db.Column(StringUUID) read_at = db.Column(db.DateTime) - read_account_id = db.Column(UUID) + read_account_id = db.Column(StringUUID) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -626,12 +626,12 @@ class Message(db.Model): db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) inputs = db.Column(db.JSON) query = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) @@ -650,12 +650,12 @@ class Message(db.Model): message_metadata = db.Column(db.Text) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(UUID) - from_account_id = db.Column(UUID) + from_end_user_id = db.Column(StringUUID) + from_account_id = db.Column(StringUUID) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - workflow_run_id = db.Column(UUID) + workflow_run_id = db.Column(StringUUID) @property def re_sign_file_url_answer(self) -> str: @@ -846,15 +846,15 @@ class MessageFeedback(db.Model): db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - conversation_id = db.Column(UUID, nullable=False) - message_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + conversation_id = db.Column(StringUUID, nullable=False) + message_id = db.Column(StringUUID, nullable=False) rating = db.Column(db.String(255), nullable=False) content = db.Column(db.Text) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(UUID) - from_account_id = db.Column(UUID) + from_end_user_id = db.Column(StringUUID) + from_account_id = db.Column(StringUUID) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -872,15 +872,15 @@ class MessageFile(db.Model): db.Index('message_file_created_by_idx', 'created_by') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) transfer_method = db.Column(db.String(255), nullable=False) url = db.Column(db.Text, nullable=True) belongs_to = db.Column(db.String(255), nullable=True) - upload_file_id = db.Column(UUID, nullable=True) + upload_file_id = db.Column(StringUUID, nullable=True) created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -893,14 +893,14 @@ class MessageAnnotation(db.Model): db.Index('message_annotation_message_idx', 'message_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True) - message_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - account_id = db.Column(UUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -925,15 +925,15 @@ class AppAnnotationHitHistory(db.Model): db.Index('app_annotation_hit_histories_message_idx', 'message_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - annotation_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) - account_id = db.Column(UUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) score = db.Column(Float, nullable=False, server_default=db.text('0')) - message_id = db.Column(UUID, nullable=False) + message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False) @@ -957,13 +957,13 @@ class AppAnnotationSetting(db.Model): db.Index('app_annotation_settings_app_idx', 'app_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) - collection_binding_id = db.Column(UUID, nullable=False) - created_user_id = db.Column(UUID, nullable=False) + collection_binding_id = db.Column(StringUUID, nullable=False) + created_user_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_user_id = db.Column(UUID, nullable=False) + updated_user_id = db.Column(StringUUID, nullable=False) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @property @@ -995,9 +995,9 @@ class OperationLog(db.Model): db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - account_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -1013,9 +1013,9 @@ class EndUser(UserMixin, db.Model): db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) @@ -1033,8 +1033,8 @@ class Site(db.Model): db.Index('site_code_idx', 'code', 'status') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) @@ -1074,9 +1074,9 @@ class ApiToken(db.Model): db.Index('api_token_tenant_idx', 'tenant_id', 'type') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=True) - tenant_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=True) + tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) @@ -1099,8 +1099,8 @@ class UploadFile(db.Model): db.Index('upload_file_tenant_idx', 'tenant_id') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) storage_type = db.Column(db.String(255), nullable=False) key = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) @@ -1108,10 +1108,10 @@ class UploadFile(db.Model): extension = db.Column(db.String(255), nullable=False) mime_type = db.Column(db.String(255), nullable=True) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - used_by = db.Column(UUID, nullable=True) + used_by = db.Column(StringUUID, nullable=True) used_at = db.Column(db.DateTime, nullable=True) hash = db.Column(db.String(255), nullable=True) @@ -1123,9 +1123,9 @@ class ApiRequest(db.Model): db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') ) - id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - api_token_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) @@ -1140,8 +1140,8 @@ class MessageChain(db.Model): db.Index('message_chain_message_id_idx', 'message_id') ) - id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) output = db.Column(db.Text, nullable=True) @@ -1156,9 +1156,9 @@ class MessageAgentThought(db.Model): db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), ) - id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(UUID, nullable=False) - message_chain_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(StringUUID, nullable=False) + message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) thought = db.Column(db.Text, nullable=True) tool = db.Column(db.Text, nullable=True) @@ -1166,7 +1166,7 @@ class MessageAgentThought(db.Model): tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text")) tool_input = db.Column(db.Text, nullable=True) observation = db.Column(db.Text, nullable=True) - # plugin_id = db.Column(UUID, nullable=True) ## for future design + # plugin_id = db.Column(StringUUID, nullable=True) ## for future design tool_process_data = db.Column(db.Text, nullable=True) message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) @@ -1182,7 +1182,7 @@ class MessageAgentThought(db.Model): currency = db.Column(db.String, nullable=True) latency = db.Column(db.Float, nullable=True) created_by_role = db.Column(db.String, nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property @@ -1273,15 +1273,15 @@ class DatasetRetrieverResource(db.Model): db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), ) - id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) - dataset_id = db.Column(UUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) dataset_name = db.Column(db.Text, nullable=False) - document_id = db.Column(UUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) document_name = db.Column(db.Text, nullable=False) data_source_type = db.Column(db.Text, nullable=False) - segment_id = db.Column(UUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) score = db.Column(db.Float, nullable=True) content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=True) @@ -1289,7 +1289,7 @@ class DatasetRetrieverResource(db.Model): segment_position = db.Column(db.Integer, nullable=True) index_node_hash = db.Column(db.Text, nullable=True) retriever_from = db.Column(db.Text, nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1303,11 +1303,11 @@ class Tag(db.Model): TAG_TYPE_LIST = ['knowledge', 'app'] - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=True) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -1319,9 +1319,9 @@ class TagBinding(db.Model): db.Index('tag_bind_tag_id_idx', 'tag_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=True) - tag_id = db.Column(UUID, nullable=True) - target_id = db.Column(UUID, nullable=True) - created_by = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=True) + tag_id = db.Column(StringUUID, nullable=True) + target_id = db.Column(StringUUID, nullable=True) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/provider.py b/api/models/provider.py index 4c9fd793cc..413e8f9d67 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,8 +1,7 @@ from enum import Enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from models import StringUUID class ProviderType(Enum): @@ -46,8 +45,8 @@ class Provider(db.Model): db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(40), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) @@ -93,8 +92,8 @@ class ProviderModel(db.Model): db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) @@ -111,8 +110,8 @@ class TenantDefaultModel(db.Model): db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(40), nullable=False) model_type = db.Column(db.String(40), nullable=False) @@ -127,8 +126,8 @@ class TenantPreferredModelProvider(db.Model): db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(40), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -142,10 +141,10 @@ class ProviderOrder(db.Model): db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(40), nullable=False) - account_id = db.Column(UUID, nullable=False) + account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191)) diff --git a/api/models/source.py b/api/models/source.py index 8afe0f9522..97ba23a5bd 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,6 +1,7 @@ -from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.postgresql import JSONB from extensions.ext_database import db +from models import StringUUID class DataSourceBinding(db.Model): @@ -11,8 +12,8 @@ class DataSourceBinding(db.Model): db.Index('source_info_idx', "source_info", postgresql_using='gin') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) diff --git a/api/models/tool.py b/api/models/tool.py index ac866e20a4..f322944f5f 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -1,9 +1,8 @@ import json from enum import Enum -from sqlalchemy.dialects.postgresql import UUID - from extensions.ext_database import db +from models import StringUUID class ToolProviderName(Enum): @@ -24,8 +23,8 @@ class ToolProvider(db.Model): db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) tool_name = db.Column(db.String(40), nullable=False) encrypted_credentials = db.Column(db.Text, nullable=True) is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) diff --git a/api/models/tools.py b/api/models/tools.py index 414d055780..8a133679e0 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,12 +1,12 @@ import json from sqlalchemy import ForeignKey -from sqlalchemy.dialects.postgresql import UUID from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiBasedToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType from extensions.ext_database import db +from models import StringUUID from models.model import Account, App, Tenant @@ -22,11 +22,11 @@ class BuiltinToolProvider(db.Model): ) # id of the tool provider - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # id of the tenant - tenant_id = db.Column(UUID, nullable=True) + tenant_id = db.Column(StringUUID, nullable=True) # who created this tool provider - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # name of the tool provider provider = db.Column(db.String(40), nullable=False) # credential of the tool provider @@ -49,11 +49,11 @@ class PublishedAppTool(db.Model): ) # id of the tool provider - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # id of the app - app_id = db.Column(UUID, ForeignKey('apps.id'), nullable=False) + app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False) # who published this tool - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # description of the tool, stored in i18n format, for human description = db.Column(db.Text, nullable=False) # llm_description of the tool, for LLM @@ -87,7 +87,7 @@ class ApiToolProvider(db.Model): db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # name of the api provider name = db.Column(db.String(40), nullable=False) # icon @@ -96,9 +96,9 @@ class ApiToolProvider(db.Model): schema = db.Column(db.Text, nullable=False) schema_type_str = db.Column(db.String(40), nullable=False) # who created this tool - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) # description of the provider description = db.Column(db.Text, nullable=False) # json format tools @@ -140,11 +140,11 @@ class ToolModelInvoke(db.Model): db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # who invoke this tool - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) # provider provider = db.Column(db.String(40), nullable=False) # type @@ -180,13 +180,13 @@ class ToolConversationVariables(db.Model): db.Index('conversation_id_idx', 'conversation_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # conversation user id - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) # conversation id - conversation_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(StringUUID, nullable=False) # variables pool variables_str = db.Column(db.Text, nullable=False) @@ -208,13 +208,13 @@ class ToolFile(db.Model): db.Index('tool_file_conversation_id_idx', 'conversation_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) # conversation user id - user_id = db.Column(UUID, nullable=False) + user_id = db.Column(StringUUID, nullable=False) # tenant id - tenant_id = db.Column(UUID, nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) # conversation id - conversation_id = db.Column(UUID, nullable=True) + conversation_id = db.Column(StringUUID, nullable=True) # file key file_key = db.Column(db.String(255), nullable=False) # mime type diff --git a/api/models/web.py b/api/models/web.py index b2466430b9..6fd27206a9 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,6 +1,6 @@ -from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db +from models import StringUUID from models.model import Message @@ -11,11 +11,11 @@ class SavedMessage(db.Model): db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - message_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @property @@ -30,9 +30,9 @@ class PinnedConversation(db.Model): db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - app_id = db.Column(UUID, nullable=False) - conversation_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(StringUUID, nullable=False) + conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/models/workflow.py b/api/models/workflow.py index f65eba3637..f261c67c77 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,10 +2,9 @@ import json from enum import Enum from typing import Optional, Union -from sqlalchemy.dialects.postgresql import UUID - from core.tools.tool_manager import ToolManager from extensions.ext_database import db +from models import StringUUID from models.account import Account @@ -102,16 +101,16 @@ class Workflow(db.Model): db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) features = db.Column(db.Text) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(UUID) + updated_by = db.Column(StringUUID) updated_at = db.Column(db.DateTime) @property @@ -245,11 +244,11 @@ class WorkflowRun(db.Model): db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(UUID, nullable=False) + workflow_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) triggered_from = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False) @@ -262,7 +261,7 @@ class WorkflowRun(db.Model): total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) total_steps = db.Column(db.Integer, server_default=db.text('0')) created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) finished_at = db.Column(db.DateTime) @@ -404,12 +403,12 @@ class WorkflowNodeExecution(db.Model): 'triggered_from', 'node_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=False) - workflow_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + workflow_id = db.Column(StringUUID, nullable=False) triggered_from = db.Column(db.String(255), nullable=False) - workflow_run_id = db.Column(UUID) + workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) node_id = db.Column(db.String(255), nullable=False) @@ -424,7 +423,7 @@ class WorkflowNodeExecution(db.Model): execution_metadata = db.Column(db.Text) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @property @@ -529,14 +528,14 @@ class WorkflowAppLog(db.Model): db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), ) - id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(UUID, nullable=False) - app_id = db.Column(UUID, nullable=False) - workflow_id = db.Column(UUID, nullable=False) - workflow_run_id = db.Column(UUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(StringUUID, nullable=False) + app_id = db.Column(StringUUID, nullable=False) + workflow_id = db.Column(StringUUID, nullable=False) + workflow_run_id = db.Column(StringUUID, nullable=False) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(UUID, nullable=False) + created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @property diff --git a/api/requirements.txt b/api/requirements.txt index bcf248a4c4..485733ed9c 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,7 +1,7 @@ beautifulsoup4==4.12.2 flask~=3.0.1 Flask-SQLAlchemy~=3.0.5 -SQLAlchemy~=1.4.28 +SQLAlchemy~=2.0.29 Flask-Compress~=1.14 flask-login~=0.6.3 flask-migrate~=4.0.5 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py b/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000..a2b9669b90 --- /dev/null +++ b/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,37 @@ +from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig +from tests.integration_tests.vdb.test_vector_store import ( + AbstractVectorTest, + get_example_text, + setup_mock_redis, +) + + +class TestPgvectoRSVector(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = PGVectoRS( + collection_name=self.collection_name.lower(), + config=PgvectoRSConfig( + host='localhost', + port=5431, + user='postgres', + password='difyai123456', + database='dify', + ), + dim=128 + ) + + def search_by_full_text(self): + # pgvecto rs only support english text search, So it’s not open for now + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + def delete_by_document_id(self): + self.vector.delete_by_document_id(document_id=self.example_doc_id) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 1 + +def test_pgvecot_rs(setup_mock_redis): + TestPgvectoRSVector().run_all_tests() diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/tests/integration_tests/vdb/test_vector_store.py index f7698c903f..3930daf484 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/tests/integration_tests/vdb/test_vector_store.py @@ -45,7 +45,7 @@ class AbstractVectorTest: def __init__(self): self.vector = None self.dataset_id = str(uuid.uuid4()) - self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test' self.example_doc_id = str(uuid.uuid4()) self.example_embedding = [1.001 * i for i in range(128)] diff --git a/docker/docker-compose.pgvecto-rs.yaml b/docker/docker-compose.pgvecto-rs.yaml new file mode 100644 index 0000000000..a083302b1e --- /dev/null +++ b/docker/docker-compose.pgvecto-rs.yaml @@ -0,0 +1,24 @@ +version: '3' +services: + # The pgvecto—rs database. + pgvecto-rs: + image: tensorchord/pgvecto-rs:pg16-v0.2.0 + restart: always + environment: + PGUSER: postgres + # The password for the default postgres user. + POSTGRES_PASSWORD: difyai123456 + # The name of the default postgres database. + POSTGRES_DB: dify + # postgres data directory + PGDATA: /var/lib/postgresql/data/pgdata + volumes: + - ./volumes/pgvectors/data:/var/lib/postgresql/data + # uncomment to expose db(postgresql) port to host + ports: + - "5431:5432" + healthcheck: + test: [ "CMD", "pg_isready" ] + interval: 1s + timeout: 3s + retries: 30