refactor(rag): switch to dify_config. (#6410)

Co-authored-by: -LAN- <laipz8200@outlook.com>
Authored by Poorandy on 2024-07-18 18:40:36 +08:00, committed by GitHub
parent 27c8deb4ec
commit c8f5dfcf17
18 changed files with 121 additions and 131 deletions
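
At a high level, every call site that previously pulled settings out of Flask's per-app dictionary (current_app.config.get(...) or current_app.config[...]) now reads typed attributes from the module-level dify_config object in the configs package. A minimal sketch of the idea, assuming a pydantic-settings style class; the field names below come from this diff, but the class shape and defaults are illustrative assumptions, not the exact Dify implementation:

# Illustrative only: a pydantic-settings style config object similar in spirit
# to `configs.dify_config`. Field names match the diff; defaults are assumptions.
from typing import Optional

from pydantic_settings import BaseSettings


class ExampleConfig(BaseSettings):
    VECTOR_STORE: Optional[str] = None
    CHROMA_HOST: Optional[str] = None
    CHROMA_PORT: int = 8000   # typed: env strings are coerced, so no int(...) cast at call sites
    ETL_TYPE: str = "dify"


dify_config = ExampleConfig()  # values are read from the environment once, at construction

# Before: needs an active Flask application context and returns untyped values
#   port = int(current_app.config.get("CHROMA_PORT"))
# After: importable anywhere, already validated and typed
#   port = dify_config.CHROMA_PORT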

View File

@@ -7,8 +7,8 @@ _import_err_msg = (
     "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
     "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
 )
-from flask import current_app
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
         )
-        config = current_app.config
+        # TODO handle optional params
         return AnalyticdbVector(
             collection_name,
             AnalyticdbConfig(
-                access_key_id=config.get("ANALYTICDB_KEY_ID"),
-                access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
-                region_id=config.get("ANALYTICDB_REGION_ID"),
-                instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
-                account=config.get("ANALYTICDB_ACCOUNT"),
-                account_password=config.get("ANALYTICDB_PASSWORD"),
-                namespace=config.get("ANALYTICDB_NAMESPACE"),
-                namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
+                access_key_id=dify_config.ANALYTICDB_KEY_ID,
+                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
+                region_id=dify_config.ANALYTICDB_REGION_ID,
+                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
+                account=dify_config.ANALYTICDB_ACCOUNT,
+                account_password=dify_config.ANALYTICDB_PASSWORD,
+                namespace=dify_config.ANALYTICDB_NAMESPACE,
+                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            ),
        )

View File

@@ -3,9 +3,9 @@ from typing import Any, Optional
 import chromadb
 from chromadb import QueryResult, Settings
-from flask import current_app
 from pydantic import BaseModel
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -133,15 +133,14 @@ class ChromaVectorFactory(AbstractVectorFactory):
         }
         dataset.index_struct = json.dumps(index_struct_dict)
-        config = current_app.config
         return ChromaVector(
             collection_name=collection_name,
             config=ChromaConfig(
-                host=config.get('CHROMA_HOST'),
-                port=int(config.get('CHROMA_PORT')),
-                tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
-                database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
-                auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
-                auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
+                host=dify_config.CHROMA_HOST,
+                port=dify_config.CHROMA_PORT,
+                tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
+                database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
+                auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
+                auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
            ),
        )
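
One behavioral nuance in hunks like the Chroma one above: config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT) falls back only when the key is missing, whereas dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT also falls back when the configured value is None or an empty string. A small standalone illustration, not Dify code:

# Standalone illustration of the two fallback styles; not Dify code.
legacy = {"CHROMA_TENANT": ""}   # key present but empty

value_get = legacy.get("CHROMA_TENANT", "default_tenant")
value_or = legacy.get("CHROMA_TENANT") or "default_tenant"

print(value_get)  # ""               -> the empty string is kept
print(value_or)   # "default_tenant" -> falsy values are replaced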

View File

@@ -3,10 +3,10 @@ import logging
 from typing import Any, Optional
 from uuid import uuid4
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from pymilvus import MilvusClient, MilvusException, connections
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
-        config = current_app.config
         return MilvusVector(
             collection_name=collection_name,
             config=MilvusConfig(
-                host=config.get('MILVUS_HOST'),
-                port=config.get('MILVUS_PORT'),
-                user=config.get('MILVUS_USER'),
-                password=config.get('MILVUS_PASSWORD'),
-                secure=config.get('MILVUS_SECURE'),
-                database=config.get('MILVUS_DATABASE'),
+                host=dify_config.MILVUS_HOST,
+                port=dify_config.MILVUS_PORT,
+                user=dify_config.MILVUS_USER,
+                password=dify_config.MILVUS_PASSWORD,
+                secure=dify_config.MILVUS_SECURE,
+                database=dify_config.MILVUS_DATABASE,
            )
        )

View File

@@ -5,9 +5,9 @@ from enum import Enum
 from typing import Any
 from clickhouse_connect import get_client
-from flask import current_app
 from pydantic import BaseModel
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,15 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
-        config = current_app.config
         return MyScaleVector(
             collection_name=collection_name,
             config=MyScaleConfig(
-                host=config.get("MYSCALE_HOST", "localhost"),
-                port=int(config.get("MYSCALE_PORT", 8123)),
-                user=config.get("MYSCALE_USER", "default"),
-                password=config.get("MYSCALE_PASSWORD", ""),
-                database=config.get("MYSCALE_DATABASE", "default"),
-                fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
+                # TODO: I think setting those values as the default config would be a better option.
+                host=dify_config.MYSCALE_HOST or "localhost",
+                port=dify_config.MYSCALE_PORT or 8123,
+                user=dify_config.MYSCALE_USER or "default",
+                password=dify_config.MYSCALE_PASSWORD or "",
+                database=dify_config.MYSCALE_DATABASE or "default",
+                fts_params=dify_config.MYSCALE_FTS_PARAMS or "",
            ),
        )
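
The TODO comment added in this hunk points at a cleaner follow-up: declaring the fallback values as defaults on the settings class itself, so call sites can drop the or-chains. A hedged sketch of what that could look like, with illustrative field declarations rather than the actual Dify config:

# Illustrative sketch of moving MyScale defaults onto the settings class,
# as the TODO in this hunk suggests; not the actual Dify implementation.
from pydantic_settings import BaseSettings


class MyScaleSettingsSketch(BaseSettings):
    MYSCALE_HOST: str = "localhost"
    MYSCALE_PORT: int = 8123
    MYSCALE_USER: str = "default"
    MYSCALE_PASSWORD: str = ""
    MYSCALE_DATABASE: str = "default"
    MYSCALE_FTS_PARAMS: str = ""

# With defaults owned by the settings class, the factory could pass
# host=dify_config.MYSCALE_HOST directly, without `or "localhost"`.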

View File

@@ -4,11 +4,11 @@ import ssl
 from typing import Any, Optional
 from uuid import uuid4
-from flask import current_app
 from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
-        config = current_app.config
         open_search_config = OpenSearchConfig(
-            host=config.get('OPENSEARCH_HOST'),
-            port=config.get('OPENSEARCH_PORT'),
-            user=config.get('OPENSEARCH_USER'),
-            password=config.get('OPENSEARCH_PASSWORD'),
-            secure=config.get('OPENSEARCH_SECURE'),
+            host=dify_config.OPENSEARCH_HOST,
+            port=dify_config.OPENSEARCH_PORT,
+            user=dify_config.OPENSEARCH_USER,
+            password=dify_config.OPENSEARCH_PASSWORD,
+            secure=dify_config.OPENSEARCH_SECURE,
         )
         return OpenSearchVector(

View File

@@ -6,9 +6,9 @@ from typing import Any
 import numpy
 import oracledb
-from flask import current_app
 from pydantic import BaseModel, model_validator
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
-        config = current_app.config
         return OracleVector(
             collection_name=collection_name,
             config=OracleVectorConfig(
-                host=config.get("ORACLE_HOST"),
-                port=config.get("ORACLE_PORT"),
-                user=config.get("ORACLE_USER"),
-                password=config.get("ORACLE_PASSWORD"),
-                database=config.get("ORACLE_DATABASE"),
+                host=dify_config.ORACLE_HOST,
+                port=dify_config.ORACLE_PORT,
+                user=dify_config.ORACLE_USER,
+                password=dify_config.ORACLE_PASSWORD,
+                database=dify_config.ORACLE_DATABASE,
            ),
        )

View File

@@ -3,7 +3,6 @@ import logging
 from typing import Any
 from uuid import UUID, uuid4
-from flask import current_app
 from numpy import ndarray
 from pgvecto_rs.sqlalchemy import Vector
 from pydantic import BaseModel, model_validator
@@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Mapped, Session, mapped_column
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
         dim = len(embeddings.embed_query("pgvecto_rs"))
-        config = current_app.config
         return PGVectoRS(
             collection_name=collection_name,
             config=PgvectoRSConfig(
-                host=config.get('PGVECTO_RS_HOST'),
-                port=config.get('PGVECTO_RS_PORT'),
-                user=config.get('PGVECTO_RS_USER'),
-                password=config.get('PGVECTO_RS_PASSWORD'),
-                database=config.get('PGVECTO_RS_DATABASE'),
+                host=dify_config.PGVECTO_RS_HOST,
+                port=dify_config.PGVECTO_RS_PORT,
+                user=dify_config.PGVECTO_RS_USER,
+                password=dify_config.PGVECTO_RS_PASSWORD,
+                database=dify_config.PGVECTO_RS_DATABASE,
            ),
            dim=dim
        )

View File

@@ -5,9 +5,9 @@ from typing import Any
 import psycopg2.extras
 import psycopg2.pool
-from flask import current_app
 from pydantic import BaseModel, model_validator
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
-        config = current_app.config
         return PGVector(
             collection_name=collection_name,
             config=PGVectorConfig(
-                host=config.get("PGVECTOR_HOST"),
-                port=config.get("PGVECTOR_PORT"),
-                user=config.get("PGVECTOR_USER"),
-                password=config.get("PGVECTOR_PASSWORD"),
-                database=config.get("PGVECTOR_DATABASE"),
+                host=dify_config.PGVECTOR_HOST,
+                port=dify_config.PGVECTOR_PORT,
+                user=dify_config.PGVECTOR_USER,
+                password=dify_config.PGVECTOR_PASSWORD,
+                database=dify_config.PGVECTOR_DATABASE,
            ),
        )

View File

@@ -19,6 +19,7 @@ from qdrant_client.http.models import (
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -444,11 +445,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
             collection_name=collection_name,
             group_id=dataset.id,
             config=QdrantConfig(
-                endpoint=config.get('QDRANT_URL'),
-                api_key=config.get('QDRANT_API_KEY'),
+                endpoint=dify_config.QDRANT_URL,
+                api_key=dify_config.QDRANT_API_KEY,
                 root_path=config.root_path,
-                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-                grpc_port=config.get('QDRANT_GRPC_PORT'),
-                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
+                grpc_port=dify_config.QDRANT_GRPC_PORT,
+                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
            )
        )

View File

@@ -2,7 +2,6 @@ import json
 import uuid
 from typing import Any, Optional
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
@@ -19,6 +18,7 @@ try:
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
+from configs import dify_config
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
@@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.RELYT, collection_name))
-        config = current_app.config
         return RelytVector(
             collection_name=collection_name,
             config=RelytConfig(
-                host=config.get('RELYT_HOST'),
-                port=config.get('RELYT_PORT'),
-                user=config.get('RELYT_USER'),
-                password=config.get('RELYT_PASSWORD'),
-                database=config.get('RELYT_DATABASE'),
+                host=dify_config.RELYT_HOST,
+                port=dify_config.RELYT_PORT,
+                user=dify_config.RELYT_USER,
+                password=dify_config.RELYT_PASSWORD,
+                database=dify_config.RELYT_DATABASE,
            ),
            group_id=dataset.id
        )

View File

@@ -1,13 +1,13 @@
 import json
 from typing import Any, Optional
-from flask import current_app
 from pydantic import BaseModel
 from tcvectordb import VectorDBClient
 from tcvectordb.model import document, enum
 from tcvectordb.model import index as vdb_index
 from tcvectordb.model.document import Filter
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
-        config = current_app.config
         return TencentVector(
             collection_name=collection_name,
             config=TencentConfig(
-                url=config.get('TENCENT_VECTOR_DB_URL'),
-                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
-                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
-                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
-                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
-                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
-                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
+                url=dify_config.TENCENT_VECTOR_DB_URL,
+                api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
+                timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
+                username=dify_config.TENCENT_VECTOR_DB_USERNAME,
+                database=dify_config.TENCENT_VECTOR_DB_DATABASE,
+                shard=dify_config.TENCENT_VECTOR_DB_SHARD,
+                replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
            )
        )

View File

@@ -3,12 +3,12 @@ import logging
 from typing import Any
 import sqlalchemy
-from flask import current_app
 from pydantic import BaseModel, model_validator
 from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
 from sqlalchemy import text as sql_text
 from sqlalchemy.orm import Session, declarative_base
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory):
         dataset.index_struct = json.dumps(
             self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
-        config = current_app.config
         return TiDBVector(
             collection_name=collection_name,
             config=TiDBVectorConfig(
-                host=config.get('TIDB_VECTOR_HOST'),
-                port=config.get('TIDB_VECTOR_PORT'),
-                user=config.get('TIDB_VECTOR_USER'),
-                password=config.get('TIDB_VECTOR_PASSWORD'),
-                database=config.get('TIDB_VECTOR_DATABASE'),
-                program_name=config.get('APPLICATION_NAME'),
+                host=dify_config.TIDB_VECTOR_HOST,
+                port=dify_config.TIDB_VECTOR_PORT,
+                user=dify_config.TIDB_VECTOR_USER,
+                password=dify_config.TIDB_VECTOR_PASSWORD,
+                database=dify_config.TIDB_VECTOR_DATABASE,
+                program_name=dify_config.APPLICATION_NAME,
            ),
        )

View File

@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
-from flask import current_app
+from configs import dify_config
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -37,8 +36,7 @@ class Vector:
         self._vector_processor = self._init_vector()

     def _init_vector(self) -> BaseVector:
-        config = current_app.config
-        vector_type = config.get('VECTOR_STORE')
+        vector_type = dify_config.VECTOR_STORE
         if self._dataset.index_struct_dict:
             vector_type = self._dataset.index_struct_dict['type']
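
The Vector._init_vector() change above also shows the practical benefit of the refactor: dify_config is a plain importable object, so code running outside a Flask request or app context (for example Celery tasks or one-off scripts) can resolve the configured vector store directly. Roughly, as a simplified sketch with a hypothetical helper function for illustration:

# Simplified, hypothetical helper showing config access without a Flask
# application context; VECTOR_STORE is the setting read in the hunk above.
from configs import dify_config


def resolve_vector_type(index_struct_dict: dict | None) -> str | None:
    # Mirrors the logic in Vector._init_vector(): dataset-level index
    # metadata wins, otherwise fall back to the globally configured store.
    vector_type = dify_config.VECTOR_STORE
    if index_struct_dict:
        vector_type = index_struct_dict["type"]
    return vector_type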

View File

@@ -4,9 +4,9 @@ from typing import Any, Optional
 import requests
 import weaviate
-from flask import current_app
 from pydantic import BaseModel, model_validator
+from configs import dify_config
 from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
@@ -281,9 +281,9 @@ class WeaviateVectorFactory(AbstractVectorFactory):
         return WeaviateVector(
             collection_name=collection_name,
             config=WeaviateConfig(
-                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
-                api_key=current_app.config.get('WEAVIATE_API_KEY'),
-                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+                endpoint=dify_config.WEAVIATE_ENDPOINT,
+                api_key=dify_config.WEAVIATE_API_KEY,
+                batch_size=dify_config.WEAVIATE_BATCH_SIZE
            ),
            attributes=attributes
        )

View File

@@ -5,8 +5,8 @@ from typing import Union
 from urllib.parse import unquote
 import requests
-from flask import current_app
+from configs import dify_config
 from core.rag.extractor.csv_extractor import CSVExtractor
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -94,9 +94,9 @@ class ExtractProcessor:
             storage.download(upload_file.key, file_path)
             input_file = Path(file_path)
             file_extension = input_file.suffix.lower()
-            etl_type = current_app.config['ETL_TYPE']
-            unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
-            unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
+            etl_type = dify_config.ETL_TYPE
+            unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
+            unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
             if etl_type == 'Unstructured':
                 if file_extension == '.xlsx' or file_extension == '.xls':
                     extractor = ExcelExtractor(file_path)

View File

@@ -3,8 +3,8 @@ import logging
 from typing import Any, Optional
 import requests
-from flask import current_app
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor):
             self._notion_access_token = self._get_access_token(tenant_id,
                                                                self._notion_workspace_id)
         if not self._notion_access_token:
-            integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+            integration_token = dify_config.NOTION_INTEGRATION_TOKEN
             if integration_token is None:
                 raise ValueError(
                     "Must specify `integration_token` or set environment "

View File

@@ -8,8 +8,8 @@ from urllib.parse import urlparse
 import requests
 from docx import Document as DocxDocument
-from flask import current_app
+from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor):
                 storage.save(file_key, rel.target_part.blob)
                 # save file to db
-                config = current_app.config
                 upload_file = UploadFile(
                     tenant_id=self.tenant_id,
-                    storage_type=config['STORAGE_TYPE'],
+                    storage_type=dify_config.STORAGE_TYPE,
                     key=file_key,
                     name=file_key,
                     size=0,
@@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor):
                 db.session.add(upload_file)
                 db.session.commit()
-                image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+                image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
         return image_map

View File

@@ -2,8 +2,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
-from flask import current_app
+from configs import dify_config
 from core.model_manager import ModelInstance
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.models.document import Document
@@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC):
             # The user-defined segmentation rule
             rules = processing_rule['rules']
             segmentation = rules["segmentation"]
-            max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
             if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
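
Note how the last hunk also drops the int(...) cast: with a typed config object the value arrives as an integer, and the 50-to-maximum bound check reads directly against it. A compact standalone sketch of that validation; the names mirror the hunk, but the helper function itself is illustrative:

# Standalone sketch of the segmentation-length check from the hunk above;
# the helper function is illustrative, not part of the Dify codebase.
def validate_max_tokens(max_tokens: int, max_segmentation_tokens_length: int) -> int:
    if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
        raise ValueError(
            f"Custom segment length should be between 50 and {max_segmentation_tokens_length}."
        )
    return max_tokens


# Example: with a configured limit of 1000,
# validate_max_tokens(4000, 1000) raises, while validate_max_tokens(500, 1000) passes.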