Merge branch 'feat/r2' into deploy/dev

jyong 2025-05-15 16:45:09 +08:00
commit 3c90a7acee
5 changed files with 58 additions and 67 deletions

View File

@@ -469,6 +469,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         rag_pipeline_service = RagPipelineService()
         return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)

 class RagPipelineConfigApi(Resource):
     """Resource for rag pipeline configuration."""

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional, Union
-from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
+from pydantic import BaseModel, Field, ValidationInfo, field_serializer, field_validator, model_validator
 from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY
 from core.entities.provider_entities import ProviderConfig
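
The only functional change in this hunk is dropping ConfigDict from the pydantic import, which suggests no model in the module still declares a model_config override. For reference, a minimal pydantic v2 sketch of what the removed import does (the model below is illustrative, not taken from this file):

from pydantic import BaseModel, ConfigDict, Field


class DatasourceEntitySketch(BaseModel):
    # ConfigDict sets model-wide options in pydantic v2, such as rejecting
    # unknown fields; once no model declares model_config, the import is dead.
    model_config = ConfigDict(extra="forbid")

    name: str = Field(description="datasource name")


print(DatasourceEntitySketch(name="notion").name)  # -> notion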

View File

@@ -192,10 +192,12 @@ class ToolProviderID(GenericProviderID):
         if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
             self.plugin_name = f"{self.provider_name}_tool"

 class DatasourceProviderID(GenericProviderID):
     def __init__(self, value: str, is_hardcoded: bool = False) -> None:
         super().__init__(value, is_hardcoded)

 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
         Github = PluginInstallationSource.Github.value
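
The hunk above only adds context around DatasourceProviderID, whose __init__ simply delegates to GenericProviderID, leaving room for datasource-specific behavior later. For readers unfamiliar with these IDs, a minimal runnable sketch of the pattern, assuming the three-part "organization/plugin_name/provider_name" convention these classes parse (class names and the sample ID are illustrative, not the real implementations):

class ProviderIDSketch:
    def __init__(self, value: str) -> None:
        parts = value.split("/")
        if len(parts) != 3:
            raise ValueError(f"invalid provider id: {value!r}")
        self.organization, self.plugin_name, self.provider_name = parts


class ToolProviderIDSketch(ProviderIDSketch):
    def __init__(self, value: str) -> None:
        super().__init__(value)
        # Mirrors the hardcoded remapping above: a few first-party tool
        # providers live in plugins named "<provider>_tool".
        if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
            self.plugin_name = f"{self.provider_name}_tool"


pid = ToolProviderIDSketch("langgenius/jina/jina")
print(pid.plugin_name)  # -> jina_tool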

View File

@@ -1,7 +1,13 @@
+import datetime
 import logging
 import time
 from typing import Any, cast

+from flask_login import current_user
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.variables.segments import ObjectSegment
 from core.workflow.entities.node_entities import NodeRunResult
@@ -11,7 +17,7 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, Document, RateLimitLog
 from models.workflow import WorkflowNodeExecutionStatus
-from services.dataset_service import DocumentService
+from services.dataset_service import DatasetCollectionBindingService
 from services.feature_service import FeatureService

 from .entities import KnowledgeIndexNodeData
@@ -109,14 +115,52 @@ class KnowledgeIndexNode(LLMNode):
         if not document:
             raise KnowledgeIndexNodeError(f"Document {node_data.document_id} not found.")

-        DocumentService.invoke_knowledge_index(
-            dataset=dataset,
-            document=document,
-            chunks=chunks,
-            chunk_structure=node_data.chunk_structure,
-            index_method=node_data.index_method,
-            retrieval_setting=node_data.retrieval_setting,
-        )
+        retrieval_setting = node_data.retrieval_setting
+        index_method = node_data.index_method
+        if not dataset.indexing_technique:
+            if node_data.index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+                raise ValueError("Indexing technique is invalid")
+
+            dataset.indexing_technique = index_method.indexing_technique
+            if index_method.indexing_technique == "high_quality":
+                model_manager = ModelManager()
+                if (
+                    index_method.embedding_setting.embedding_model
+                    and index_method.embedding_setting.embedding_model_provider
+                ):
+                    dataset_embedding_model = index_method.embedding_setting.embedding_model
+                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
+                else:
+                    embedding_model = model_manager.get_default_model_instance(
+                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+                    )
+                    dataset_embedding_model = embedding_model.model
+                    dataset_embedding_model_provider = embedding_model.provider
+                dataset.embedding_model = dataset_embedding_model
+                dataset.embedding_model_provider = dataset_embedding_model_provider
+                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                    dataset_embedding_model_provider, dataset_embedding_model
+                )
+                dataset.collection_binding_id = dataset_collection_binding.id
+            if not dataset.retrieval_model:
+                default_retrieval_model = {
+                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+                    "reranking_enable": False,
+                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+                    "top_k": 2,
+                    "score_threshold_enabled": False,
+                }
+                dataset.retrieval_model = (
+                    retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
+                )  # type: ignore
+
+        index_processor = IndexProcessorFactory(node_data.chunk_structure).init_index_processor()
+        index_processor.index(dataset, document, chunks)
+
+        # update document status
+        document.indexing_status = "completed"
+        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+        db.session.commit()

         return {
             "dataset_id": dataset.id,

View File

@@ -6,7 +6,7 @@ import random
 import time
 import uuid
 from collections import Counter
-from typing import Any, Literal, Optional
+from typing import Any, Optional

 from flask_login import current_user
 from sqlalchemy import func, select
@@ -20,9 +20,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.plugin.entities.plugin import ModelProviderID
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
-from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -1516,60 +1514,6 @@ class DocumentService:
         return documents, batch

-    @staticmethod
-    def invoke_knowledge_index(
-        dataset: Dataset,
-        document: Document,
-        chunks: list[Any],
-        index_method: IndexMethod,
-        retrieval_setting: RetrievalSetting,
-        chunk_structure: Literal["text_model", "hierarchical_model"],
-    ):
-        if not dataset.indexing_technique:
-            if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
-                raise ValueError("Indexing technique is invalid")
-
-            dataset.indexing_technique = index_method.indexing_technique
-            if index_method.indexing_technique == "high_quality":
-                model_manager = ModelManager()
-                if (
-                    index_method.embedding_setting.embedding_model
-                    and index_method.embedding_setting.embedding_model_provider
-                ):
-                    dataset_embedding_model = index_method.embedding_setting.embedding_model
-                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
-                else:
-                    embedding_model = model_manager.get_default_model_instance(
-                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
-                    )
-                    dataset_embedding_model = embedding_model.model
-                    dataset_embedding_model_provider = embedding_model.provider
-                dataset.embedding_model = dataset_embedding_model
-                dataset.embedding_model_provider = dataset_embedding_model_provider
-                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    dataset_embedding_model_provider, dataset_embedding_model
-                )
-                dataset.collection_binding_id = dataset_collection_binding.id
-            if not dataset.retrieval_model:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                dataset.retrieval_model = (
-                    retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
-                )  # type: ignore
-
-        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
-        index_processor.index(dataset, document, chunks)
-
-        # update document status
-        document.indexing_status = "completed"
-        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        db.session.commit()

     @staticmethod
     def check_documents_upload_quota(count: int, features: FeatureModel):
         can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
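
The diff ends inside check_documents_upload_quota, which this commit leaves untouched; the only removal in this file is the relocated invoke_knowledge_index. The quota check itself is plain arithmetic; a minimal sketch, assuming from the names above that documents_upload_quota carries a subscription limit and the current document count, and that the real method raises a service-level error where ValueError stands in here:

from dataclasses import dataclass


@dataclass
class UploadQuotaSketch:
    limit: int  # documents allowed by the subscription
    size: int   # documents already uploaded


def check_documents_upload_quota(count: int, quota: UploadQuotaSketch) -> None:
    can_upload_size = quota.limit - quota.size
    if count > can_upload_size:
        raise ValueError(f"upload quota exceeded: {count} requested, {can_upload_size} left")


try:
    check_documents_upload_quota(3, UploadQuotaSketch(limit=50, size=48))
except ValueError as e:
    print(e)  # -> upload quota exceeded: 3 requested, 2 left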