diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index d7a5dd5dcc..123b93fcd5 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -7,8 +7,8 @@ _import_err_msg = ( "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" ) -from flask import current_app +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel): "region_id": self.region_id, "read_timeout": self.read_timeout, } - + class AnalyticdbVector(BaseVector): _instance = None _init = False @@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance - + def __init__(self, collection_name: str, config: AnalyticdbConfig): # collection_name must be updated every time self._collection_name = collection_name.lower() @@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector): raise ValueError( f"failed to create namespace {self.config.namespace}: {e}" ) - + def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException @@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector): def get_type(self) -> str: return VectorType.ANALYTICDB - + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection_if_not_exists(dimension) @@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector): ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 - + def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models ids_str = ",".join(f"'{id}'" for id in ids) @@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector): ) documents.append(doc) return documents - + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models score_threshold = ( @@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector): ) documents.append(doc) return documents - + def delete(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models request = gpdb_20160503_models.DeleteCollectionRequest( @@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) ) - config = current_app.config + + # TODO handle optional params return AnalyticdbVector( collection_name, AnalyticdbConfig( - access_key_id=config.get("ANALYTICDB_KEY_ID"), - access_key_secret=config.get("ANALYTICDB_KEY_SECRET"), - region_id=config.get("ANALYTICDB_REGION_ID"), - instance_id=config.get("ANALYTICDB_INSTANCE_ID"), - account=config.get("ANALYTICDB_ACCOUNT"), - account_password=config.get("ANALYTICDB_PASSWORD"), - namespace=config.get("ANALYTICDB_NAMESPACE"), - namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"), + access_key_id=dify_config.ANALYTICDB_KEY_ID, + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, + region_id=dify_config.ANALYTICDB_REGION_ID, + instance_id=dify_config.ANALYTICDB_INSTANCE_ID, + account=dify_config.ANALYTICDB_ACCOUNT, + account_password=dify_config.ANALYTICDB_PASSWORD, + namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 2d4e1975ea..1d85fa78c0 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -3,9 +3,9 @@ from typing import Any, Optional import chromadb from chromadb import QueryResult, Settings -from flask import current_app from pydantic import BaseModel +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -133,15 +133,14 @@ class ChromaVectorFactory(AbstractVectorFactory): } dataset.index_struct = json.dumps(index_struct_dict) - config = current_app.config return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=config.get('CHROMA_HOST'), - port=int(config.get('CHROMA_PORT')), - tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT), - database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE), - auth_provider=config.get('CHROMA_AUTH_PROVIDER'), - auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'), + host=dify_config.CHROMA_HOST, + port=dify_config.CHROMA_PORT, + tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, + database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, + auth_provider=dify_config.CHROMA_AUTH_PROVIDER, + auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, ), ) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 02b715d768..5f2ab7c5fc 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -3,10 +3,10 @@ import logging from typing import Any, Optional from uuid import uuid4 -from flask import current_app from pydantic import BaseModel, model_validator from pymilvus import MilvusClient, MilvusException, connections +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) - config = current_app.config return MilvusVector( collection_name=collection_name, config=MilvusConfig( - host=config.get('MILVUS_HOST'), - port=config.get('MILVUS_PORT'), - user=config.get('MILVUS_USER'), - password=config.get('MILVUS_PASSWORD'), - secure=config.get('MILVUS_SECURE'), - database=config.get('MILVUS_DATABASE'), + host=dify_config.MILVUS_HOST, + port=dify_config.MILVUS_PORT, + user=dify_config.MILVUS_USER, + password=dify_config.MILVUS_PASSWORD, + secure=dify_config.MILVUS_SECURE, + database=dify_config.MILVUS_DATABASE, ) ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 811b08818c..33ee8259c5 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -5,9 +5,9 @@ from enum import Enum from typing import Any from clickhouse_connect import get_client -from flask import current_app from pydantic import BaseModel +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -156,15 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) - config = current_app.config return MyScaleVector( collection_name=collection_name, config=MyScaleConfig( - host=config.get("MYSCALE_HOST", "localhost"), - port=int(config.get("MYSCALE_PORT", 8123)), - user=config.get("MYSCALE_USER", "default"), - password=config.get("MYSCALE_PASSWORD", ""), - database=config.get("MYSCALE_DATABASE", "default"), - fts_params=config.get("MYSCALE_FTS_PARAMS", ""), + # TODO: I think setting those values as the default config would be a better option. + host=dify_config.MYSCALE_HOST or "localhost", + port=dify_config.MYSCALE_PORT or 8123, + user=dify_config.MYSCALE_USER or "default", + password=dify_config.MYSCALE_PASSWORD or "", + database=dify_config.MYSCALE_DATABASE or "default", + fts_params=dify_config.MYSCALE_FTS_PARAMS or "", ), ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 744ff2d517..d834e8ce14 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -4,11 +4,11 @@ import ssl from typing import Any, Optional from uuid import uuid4 -from flask import current_app from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - config = current_app.config open_search_config = OpenSearchConfig( - host=config.get('OPENSEARCH_HOST'), - port=config.get('OPENSEARCH_PORT'), - user=config.get('OPENSEARCH_USER'), - password=config.get('OPENSEARCH_PASSWORD'), - secure=config.get('OPENSEARCH_SECURE'), + host=dify_config.OPENSEARCH_HOST, + port=dify_config.OPENSEARCH_PORT, + user=dify_config.OPENSEARCH_USER, + password=dify_config.OPENSEARCH_PASSWORD, + secure=dify_config.OPENSEARCH_SECURE, ) return OpenSearchVector( diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 5f7723508c..f75310205c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -6,9 +6,9 @@ from typing import Any import numpy import oracledb -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel): SQL_CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS {table_name} ( - id varchar2(100) + id varchar2(100) ,text CLOB NOT NULL ,meta JSON ,embedding vector NOT NULL -) +) """ @@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) - config = current_app.config return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=config.get("ORACLE_HOST"), - port=config.get("ORACLE_PORT"), - user=config.get("ORACLE_USER"), - password=config.get("ORACLE_PASSWORD"), - database=config.get("ORACLE_DATABASE"), + host=dify_config.ORACLE_HOST, + port=dify_config.ORACLE_PORT, + user=dify_config.ORACLE_USER, + password=dify_config.ORACLE_PASSWORD, + database=dify_config.ORACLE_DATABASE, ), ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 63c8edfbc3..82bdc5d4b9 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -3,7 +3,6 @@ import logging from typing import Any from uuid import UUID, uuid4 -from flask import current_app from numpy import ndarray from pgvecto_rs.sqlalchemy import Vector from pydantic import BaseModel, model_validator @@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector @@ -93,7 +93,7 @@ class PGVectoRS(BaseVector): text TEXT NOT NULL, meta JSONB NOT NULL, vector vector({dimension}) NOT NULL - ) using heap; + ) using heap; """) session.execute(create_statement) index_statement = sql_text(f""" @@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) - config = current_app.config + return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=config.get('PGVECTO_RS_HOST'), - port=config.get('PGVECTO_RS_PORT'), - user=config.get('PGVECTO_RS_USER'), - password=config.get('PGVECTO_RS_PASSWORD'), - database=config.get('PGVECTO_RS_DATABASE'), + host=dify_config.PGVECTO_RS_HOST, + port=dify_config.PGVECTO_RS_PORT, + user=dify_config.PGVECTO_RS_USER, + password=dify_config.PGVECTO_RS_PASSWORD, + database=dify_config.PGVECTO_RS_DATABASE, ), dim=dim - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 72d0a85f8d..33ca5bc028 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -5,9 +5,9 @@ from typing import Any import psycopg2.extras import psycopg2.pool -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( text TEXT NOT NULL, meta JSONB NOT NULL, embedding vector({dimension}) NOT NULL -) using heap; +) using heap; """ @@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) - config = current_app.config return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=config.get("PGVECTOR_HOST"), - port=config.get("PGVECTOR_PORT"), - user=config.get("PGVECTOR_USER"), - password=config.get("PGVECTOR_PASSWORD"), - database=config.get("PGVECTOR_DATABASE"), + host=dify_config.PGVECTOR_HOST, + port=dify_config.PGVECTOR_PORT, + user=dify_config.PGVECTOR_USER, + password=dify_config.PGVECTOR_PASSWORD, + database=dify_config.PGVECTOR_DATABASE, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index bccc3a39f6..c7c0b7f6f4 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -19,6 +19,7 @@ from qdrant_client.http.models import ( ) from qdrant_client.local.qdrant_local import QdrantLocal +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -444,11 +445,11 @@ class QdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=config.get('QDRANT_URL'), - api_key=config.get('QDRANT_API_KEY'), + endpoint=dify_config.QDRANT_URL, + api_key=dify_config.QDRANT_API_KEY, root_path=config.root_path, - timeout=config.get('QDRANT_CLIENT_TIMEOUT'), - grpc_port=config.get('QDRANT_GRPC_PORT'), - prefer_grpc=config.get('QDRANT_GRPC_ENABLED') + timeout=dify_config.QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.QDRANT_GRPC_PORT, + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED ) ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 4fe1df717a..2e0bd6f303 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -2,7 +2,6 @@ import json import uuid from typing import Any, Optional -from flask import current_app from pydantic import BaseModel, model_validator from sqlalchemy import Column, Sequence, String, Table, create_engine, insert from sqlalchemy import text as sql_text @@ -19,6 +18,7 @@ try: except ImportError: from sqlalchemy.ext.declarative import declarative_base +from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document from extensions.ext_redis import redis_client @@ -85,7 +85,7 @@ class RelytVector(BaseVector): document TEXT NOT NULL, metadata JSON NOT NULL, embedding vector({dimension}) NOT NULL - ) using heap; + ) using heap; """) session.execute(create_statement) index_statement = sql_text(f""" @@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.RELYT, collection_name)) - config = current_app.config return RelytVector( collection_name=collection_name, config=RelytConfig( - host=config.get('RELYT_HOST'), - port=config.get('RELYT_PORT'), - user=config.get('RELYT_USER'), - password=config.get('RELYT_PASSWORD'), - database=config.get('RELYT_DATABASE'), + host=dify_config.RELYT_HOST, + port=dify_config.RELYT_PORT, + user=dify_config.RELYT_USER, + password=dify_config.RELYT_PASSWORD, + database=dify_config.RELYT_DATABASE, ), group_id=dataset.id ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 3af85854d2..cdcc22aec9 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -1,13 +1,13 @@ import json from typing import Any, Optional -from flask import current_app from pydantic import BaseModel from tcvectordb import VectorDBClient from tcvectordb.model import document, enum from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) - config = current_app.config return TencentVector( collection_name=collection_name, config=TencentConfig( - url=config.get('TENCENT_VECTOR_DB_URL'), - api_key=config.get('TENCENT_VECTOR_DB_API_KEY'), - timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'), - username=config.get('TENCENT_VECTOR_DB_USERNAME'), - database=config.get('TENCENT_VECTOR_DB_DATABASE'), - shard=config.get('TENCENT_VECTOR_DB_SHARD'), - replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'), + url=dify_config.TENCENT_VECTOR_DB_URL, + api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, + timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, + username=dify_config.TENCENT_VECTOR_DB_USERNAME, + database=dify_config.TENCENT_VECTOR_DB_DATABASE, + shard=dify_config.TENCENT_VECTOR_DB_SHARD, + replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, ) - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 5922db1176..d3685c0991 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -3,12 +3,12 @@ import logging from typing import Any import sqlalchemy -from flask import current_app from pydantic import BaseModel, model_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory @@ -198,8 +198,8 @@ class TiDBVector(BaseVector): with Session(self._engine) as session: select_statement = sql_text( f"""SELECT meta, text, distance FROM ( - SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance - FROM {self._collection_name} + SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance + FROM {self._collection_name} ORDER BY distance LIMIT {top_k} ) t WHERE distance < {distance};""" @@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps( self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) - config = current_app.config return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=config.get('TIDB_VECTOR_HOST'), - port=config.get('TIDB_VECTOR_PORT'), - user=config.get('TIDB_VECTOR_USER'), - password=config.get('TIDB_VECTOR_PASSWORD'), - database=config.get('TIDB_VECTOR_DATABASE'), - program_name=config.get('APPLICATION_NAME'), + host=dify_config.TIDB_VECTOR_HOST, + port=dify_config.TIDB_VECTOR_PORT, + user=dify_config.TIDB_VECTOR_USER, + password=dify_config.TIDB_VECTOR_PASSWORD, + database=dify_config.TIDB_VECTOR_DATABASE, + program_name=dify_config.APPLICATION_NAME, ), - ) \ No newline at end of file + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index f8b58e1b9a..949a4b5847 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from flask import current_app - +from configs import dify_config from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -37,8 +36,7 @@ class Vector: self._vector_processor = self._init_vector() def _init_vector(self) -> BaseVector: - config = current_app.config - vector_type = config.get('VECTOR_STORE') + vector_type = dify_config.VECTOR_STORE if self._dataset.index_struct_dict: vector_type = self._dataset.index_struct_dict['type'] diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index b7c5c96a7d..378d55472f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -4,9 +4,9 @@ from typing import Any, Optional import requests import weaviate -from flask import current_app from pydantic import BaseModel, model_validator +from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -281,9 +281,9 @@ class WeaviateVectorFactory(AbstractVectorFactory): return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=current_app.config.get('WEAVIATE_ENDPOINT'), - api_key=current_app.config.get('WEAVIATE_API_KEY'), - batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE')) + endpoint=dify_config.WEAVIATE_ENDPOINT, + api_key=dify_config.WEAVIATE_API_KEY, + batch_size=dify_config.WEAVIATE_BATCH_SIZE ), attributes=attributes ) diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 909bfdc137..d01cf48fac 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -5,8 +5,8 @@ from typing import Union from urllib.parse import unquote import requests -from flask import current_app +from configs import dify_config from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -94,9 +94,9 @@ class ExtractProcessor: storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() - etl_type = current_app.config['ETL_TYPE'] - unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] - unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY'] + etl_type = dify_config.ETL_TYPE + unstructured_api_url = dify_config.UNSTRUCTURED_API_URL + unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY if etl_type == 'Unstructured': if file_extension == '.xlsx' or file_extension == '.xls': extractor = ExcelExtractor(file_path) diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7c6101010e..9535455909 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -3,8 +3,8 @@ import logging from typing import Any, Optional import requests -from flask import current_app +from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor): self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: - integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') + integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( "Must specify `integration_token` or set environment " diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 9045966da9..ac4a56319b 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -8,8 +8,8 @@ from urllib.parse import urlparse import requests from docx import Document as DocxDocument -from flask import current_app +from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor): storage.save(file_key, rel.target_part.blob) # save file to db - config = current_app.config upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=config['STORAGE_TYPE'], + storage_type=dify_config.STORAGE_TYPE, key=file_key, name=file_key, size=0, @@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" return image_map diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index edc16c821a..33e78ce8c5 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -2,8 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from flask import current_app - +from configs import dify_config from core.model_manager import ModelInstance from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.models.document import Document @@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC): # The user-defined segmentation rule rules = processing_rule['rules'] segmentation = rules["segmentation"] - max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH']) + max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")