Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-16 17:45:54 +08:00)

Merge branch 'feat/r2' into deploy/dev

Commit 3c90a7acee
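
Summary of the diff below: the dataset-initialization and indexing bookkeeping that previously lived in DocumentService.invoke_knowledge_index is inlined into the knowledge-index workflow node, the now-unused static method is deleted from the dataset service along with the imports it needed, and an unused ConfigDict import is dropped from the plugin entities module.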
@@ -469,6 +469,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         rag_pipeline_service = RagPipelineService()
         return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
 
 
 class RagPipelineConfigApi(Resource):
     """Resource for rag pipeline configuration."""
@@ -4,7 +4,7 @@ from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator
 
 from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY
 from core.entities.provider_entities import ProviderConfig
@@ -192,10 +192,12 @@ class ToolProviderID(GenericProviderID):
         if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
             self.plugin_name = f"{self.provider_name}_tool"
 
 
 class DatasourceProviderID(GenericProviderID):
     def __init__(self, value: str, is_hardcoded: bool = False) -> None:
         super().__init__(value, is_hardcoded)
 
 
 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
         Github = PluginInstallationSource.Github.value
@@ -1,7 +1,13 @@
+import datetime
 import logging
 import time
 from typing import Any, cast
 
+from flask_login import current_user
 
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.variables.segments import ObjectSegment
 from core.workflow.entities.node_entities import NodeRunResult
@@ -11,7 +17,7 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document, RateLimitLog
 from models.workflow import WorkflowNodeExecutionStatus
-from services.dataset_service import DocumentService
+from services.dataset_service import DatasetCollectionBindingService
 from services.feature_service import FeatureService
 
 from .entities import KnowledgeIndexNodeData
@@ -109,14 +115,52 @@ class KnowledgeIndexNode(LLMNode):
         if not document:
             raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.")
 
-        DocumentService.invoke_knowledge_index(
-            dataset=dataset,
-            document=document,
-            chunks=chunks,
-            chunk_structure=node_data.chunk_structure,
-            index_method=node_data.index_method,
-            retrieval_setting=node_data.retrieval_setting,
-        )
+        retrieval_setting = node_data.retrieval_setting
+        index_method = node_data.index_method
+        if not dataset.indexing_technique:
+            if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+                raise ValueError("Indexing technique is invalid")
+
+            dataset.indexing_technique = index_method.indexing_technique
+            if index_method.indexing_technique == "high_quality":
+                model_manager = ModelManager()
+                if (
+                    index_method.embedding_setting.embedding_model
+                    and index_method.embedding_setting.embedding_model_provider
+                ):
+                    dataset_embedding_model = index_method.embedding_setting.embedding_model
+                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
+                else:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                    )
+                    dataset_embedding_model = embedding_model.model
+                    dataset_embedding_model_provider = embedding_model.provider
+                dataset.embedding_model = dataset_embedding_model
+                dataset.embedding_model_provider = dataset_embedding_model_provider
+                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                    dataset_embedding_model_provider, dataset_embedding_model
+                )
+                dataset.collection_binding_id = dataset_collection_binding.id
+            if not dataset.retrieval_model:
+                default_retrieval_model = {
+                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+                    "reranking_enable": False,
+                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+                    "top_k": 2,
+                    "score_threshold_enabled": False,
+                }
+
+                dataset.retrieval_model = (
+                    retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
+                )  # type: ignore
+        index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor()
+        index_processor.index(dataset, document, chunks)
+
+        # update document status
+        document.indexing_status = "completed"
+        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        db.session.commit()
+
         return {
             "dataset_id": dataset.id,
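Aside: stripped of the Dify-specific models, the embedding-model resolution added above is a plain "configured value, else tenant default" fallback. The sketch below restates just that pattern with hypothetical stand-ins (EmbeddingSetting and get_default_embedding are placeholders, not the real Dify APIs), so treat it as an illustration rather than the actual implementation:

from dataclasses import dataclass
from typing import Optional

@dataclass
class EmbeddingSetting:
    # Hypothetical stand-in for index_method.embedding_setting in the hunk above.
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None

def get_default_embedding(tenant_id: str) -> tuple[str, str]:
    # Placeholder for ModelManager.get_default_model_instance(...); returns (model, provider).
    return ("text-embedding-3-small", "openai")

def resolve_embedding(setting: EmbeddingSetting, tenant_id: str) -> tuple[str, str]:
    # Use the model configured on the node when both fields are set,
    # otherwise fall back to the tenant default, as the inlined code does.
    if setting.embedding_model and setting.embedding_model_provider:
        return (setting.embedding_model, setting.embedding_model_provider)
    return get_default_embedding(tenant_id)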
@@ -6,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 
 from flask_login import current_user
 from sqlalchemy import func, select
@@ -20,9 +20,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.plugin.entities.plugin import ModelProviderID
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
-from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -1516,60 +1514,6 @@ class DocumentService:
 
         return documents, batch
 
-    @staticmethod
-    def invoke_knowledge_index(
-        dataset: Dataset,
-        document: Document,
-        chunks: list[Any],
-        index_method: IndexMethod,
-        retrieval_setting: RetrievalSetting,
-        chunk_structure: Literal["text_model", "hierarchical_model"],
-    ):
-        if not dataset.indexing_technique:
-            if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
-                raise ValueError("Indexing technique is invalid")
-
-            dataset.indexing_technique = index_method.indexing_technique
-            if index_method.indexing_technique == "high_quality":
-                model_manager = ModelManager()
-                if (
-                    index_method.embedding_setting.embedding_model
-                    and index_method.embedding_setting.embedding_model_provider
-                ):
-                    dataset_embedding_model = index_method.embedding_setting.embedding_model
-                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
-                else:
-                    embedding_model = model_manager.get_default_model_instance(
-                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-                    )
-                    dataset_embedding_model = embedding_model.model
-                    dataset_embedding_model_provider = embedding_model.provider
-                dataset.embedding_model = dataset_embedding_model
-                dataset.embedding_model_provider = dataset_embedding_model_provider
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    dataset_embedding_model_provider, dataset_embedding_model
-                )
-                dataset.collection_binding_id = dataset_collection_binding.id
-            if not dataset.retrieval_model:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-
-                dataset.retrieval_model = (
-                    retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
-                )  # type: ignore
-        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
-        index_processor.index(dataset, document, chunks)
-
-        # update document status
-        document.indexing_status = "completed"
-        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        db.session.commit()
-
     @staticmethod
     def check_documents_upload_quota(count: int, features: FeatureModel):
         can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
|
Loading…
x
Reference in New Issue
Block a user