From 49d1846e63ebeeea5d0d77b41d53b54592c7947e Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Mon, 28 Apr 2025 16:19:12 +0800 Subject: [PATCH] r2 --- .../datasource/utils/message_transformer.py | 76 ++-- api/core/plugin/manager/datasource.py | 3 +- .../nodes/datasource/datasource_node.py | 122 +++--- api/core/workflow/nodes/datasource/exc.py | 12 +- api/core/workflow/nodes/enums.py | 1 + .../nodes/knowledge_index/__init__.py | 3 + .../nodes/knowledge_index/entities.py | 147 +++++++ .../workflow/nodes/knowledge_index/exc.py | 22 + .../knowledge_index/knowledge_index_node.py | 154 +++++++ .../nodes/knowledge_index/template_prompts.py | 66 +++ .../nodes/knowledge_retrieval/entities.py | 1 - api/models/dataset.py | 1 + api/services/dataset_service.py | 403 ++++++++++++++++++ 13 files changed, 902 insertions(+), 109 deletions(-) create mode 100644 api/core/workflow/nodes/knowledge_index/__init__.py create mode 100644 api/core/workflow/nodes/knowledge_index/entities.py create mode 100644 api/core/workflow/nodes/knowledge_index/exc.py create mode 100644 api/core/workflow/nodes/knowledge_index/knowledge_index_node.py create mode 100644 api/core/workflow/nodes/knowledge_index/template_prompts.py diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 6fd0c201e3..a10030d93b 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -3,58 +3,58 @@ from collections.abc import Generator from mimetypes import guess_extension from typing import Optional +from core.datasource.datasource_file_manager import DatasourceFileManager +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage from core.file import File, FileTransferMethod, FileType -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) -class ToolFileMessageTransformer: +class DatasourceFileMessageTransformer: @classmethod - def transform_tool_invoke_messages( + def transform_datasource_invoke_messages( cls, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[DatasourceInvokeMessage, None, None], user_id: str, tenant_id: str, conversation_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Transform tool message and handle file download + Transform datasource message and handle file download """ for message in messages: - if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: + if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}: yield message - elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance( - message.message, ToolInvokeMessage.TextMessage + elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance( + message.message, DatasourceInvokeMessage.TextMessage ): # try to download image try: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) - file = ToolFileManager.create_file_by_url( + file = DatasourceFileManager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, conversation_id=conversation_id, ) - url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}" + url = 
f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=message.meta.copy() if message.meta is not None else {}, ) except Exception as e: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage( + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.TEXT, + message=DatasourceInvokeMessage.TextMessage( text=f"Failed to download image: {message.message.text}: {e}" ), meta=message.meta.copy() if message.meta is not None else {}, ) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: # get mime type and save blob to storage meta = message.meta or {} @@ -63,12 +63,12 @@ class ToolFileMessageTransformer: filename = meta.get("file_name", None) # if message is str, encode it to bytes - if not isinstance(message.message, ToolInvokeMessage.BlobMessage): + if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage): raise ValueError("unexpected message type") # FIXME: should do a type check here. assert isinstance(message.message.blob, bytes) - file = ToolFileManager.create_file_by_raw( + file = DatasourceFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, @@ -77,22 +77,22 @@ class ToolFileMessageTransformer: filename=filename, ) - url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) # check if file is image if "image" in mimetype: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BINARY_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.BINARY_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == DatasourceInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): @@ -100,15 +100,15 @@ class ToolFileMessageTransformer: assert file.related_id is not None url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) if file.type == FileType.IMAGE: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.IMAGE_LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() if meta is not None else {}, ) else: - yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text=url), + yield DatasourceInvokeMessage( + type=DatasourceInvokeMessage.MessageType.LINK, + message=DatasourceInvokeMessage.TextMessage(text=url), meta=meta.copy() 
if meta is not None else {}, ) else: @@ -117,5 +117,5 @@ class ToolFileMessageTransformer: yield message @classmethod - def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: - return f"/files/tools/{tool_file_id}{extension or '.bin'}" + def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: + return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" diff --git a/api/core/plugin/manager/datasource.py b/api/core/plugin/manager/datasource.py index 5a6f557e4b..efb42cd259 100644 --- a/api/core/plugin/manager/datasource.py +++ b/api/core/plugin/manager/datasource.py @@ -88,13 +88,14 @@ class PluginDatasourceManager(BasePluginManager): response = self._request_with_plugin_daemon_response_stream( "POST", - f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step", + f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages", ToolInvokeMessage, data={ "user_id": user_id, "data": { "provider": datasource_provider_id.provider_name, "datasource": datasource_name, + "credentials": credentials, "datasource_parameters": datasource_parameters, }, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 1752ba36fa..8ecf66c0d6 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -5,13 +5,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.datasource.datasource_engine import DatasourceEngine +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter +from core.datasource.errors import DatasourceInvokeError +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.file import File, FileTransferMethod from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.plugin import PluginInstallationManager -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult @@ -29,11 +29,7 @@ from models.workflow import WorkflowNodeExecutionStatus from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import DatasourceNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) +from .exc import DatasourceFileError, DatasourceNodeError, DatasourceParameterError class DatasourceNode(BaseNode[DatasourceNodeData]): @@ -60,12 +56,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): # get datasource runtime try: - from core.tools.tool_manager import ToolManager + from core.datasource.datasource_manager import DatasourceManager - tool_runtime = ToolManager.get_workflow_tool_runtime( + datasource_runtime = DatasourceManager.get_workflow_datasource_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except ToolNodeError as e: + except DatasourceNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -78,14 +74,14 @@ class
DatasourceNode(BaseNode[DatasourceNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or [] parameters = self._generate_parameters( - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, ) parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, for_log=True, @@ -95,9 +91,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, - tool_parameters=parameters, + message_stream = DatasourceEngine.generic_invoke( + datasource=datasource_runtime, + datasource_parameters=parameters, user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, @@ -105,28 +101,28 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, ) - except ToolNodeError as e: + except DatasourceNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to invoke datasource: {str(e)}", error_type=type(e).__name__, ) ) return try: - # convert tool messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) - except (PluginDaemonClientSideError, ToolInvokeError) as e: + # convert datasource messages + yield from self._transform_message(message_stream, datasource_info, parameters_for_log) + except (PluginDaemonClientSideError, DatasourceInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to transform tool message: {str(e)}", + metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, + error=f"Failed to transform datasource message: {str(e)}", error_type=type(e).__name__, ) ) @@ -134,9 +130,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + datasource_parameters: Sequence[DatasourceParameter], variable_pool: VariablePool, - node_data: ToolNodeData, + node_data: DatasourceNodeData, for_log: bool = False, ) -> dict[str, Any]: """ @@ -151,25 +147,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): Mapping[str, Any]: A dictionary containing the generated parameters. 
""" - tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} + datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} result: dict[str, Any] = {} - for parameter_name in node_data.tool_parameters: - parameter = tool_parameters_dictionary.get(parameter_name) + for parameter_name in node_data.datasource_parameters: + parameter = datasource_parameters_dictionary.get(parameter_name) if not parameter: result[parameter_name] = None continue - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + datasource_input = node_data.datasource_parameters[parameter_name] + if datasource_input.type == "variable": + variable = variable_pool.get(datasource_input.value) if variable is None: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") + raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") parameter_value = variable.value - elif tool_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(tool_input.value)) + elif datasource_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(datasource_input.value)) parameter_value = segment_group.log if for_log else segment_group.text else: - raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") result[parameter_name] = parameter_value return result @@ -181,15 +177,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], + messages: Generator[DatasourceInvokeMessage, None, None], + datasource_info: Mapping[str, Any], parameters_for_log: dict[str, Any], ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ # transform message and handle file storage - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( messages=messages, user_id=self.user_id, tenant_id=self.tenant_id, @@ -207,11 +203,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): for message in message_stream: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + DatasourceInvokeMessage.MessageType.IMAGE_LINK, + DatasourceInvokeMessage.MessageType.BINARY_LINK, + DatasourceInvokeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) url = message.message.text if message.meta: @@ -238,9 +234,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == DatasourceInvokeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] @@ -261,14 +257,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): tenant_id=self.tenant_id, ) ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert 
isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == DatasourceInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) text += message.message.text yield RunStreamChunkEvent( chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == DatasourceInvokeMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage) if self.node_type == NodeType.AGENT: msg_metadata = message.message.json_object.pop("execution_metadata", {}) agent_execution_metadata = { @@ -277,13 +273,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): if key in NodeRunMetadataKey.__members__.values() } json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == DatasourceInvokeMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceInvokeMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -298,13 +294,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == DatasourceInvokeMessage.MessageType.FILE: assert message.meta is not None files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == DatasourceInvokeMessage.MessageType.LOG: + assert isinstance(message.message, DatasourceInvokeMessage.LogMessage) if message.message.metadata: - icon = tool_info.get("icon", "") + icon = datasource_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): manager = PluginInstallationManager() @@ -366,7 +362,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): outputs={"text": text, "files": files, "json": json, **variables}, metadata={ **agent_execution_metadata, - NodeRunMetadataKey.TOOL_INFO: tool_info, + NodeRunMetadataKey.DATASOURCE_INFO: datasource_info, NodeRunMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, @@ -379,7 +375,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ToolNodeData, + node_data: DatasourceNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -389,8 +385,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]): :return: """ result = {} - for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] if input.type == "mixed": assert isinstance(input.value, str) selectors = 
VariableTemplateParser(input.value).extract_variable_selectors() diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py index 7212e8bfc0..89980e6f45 100644 --- a/api/core/workflow/nodes/datasource/exc.py +++ b/api/core/workflow/nodes/datasource/exc.py @@ -1,16 +1,16 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" +class DatasourceNodeError(ValueError): + """Base exception for datasource node errors.""" pass -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" +class DatasourceParameterError(DatasourceNodeError): + """Exception raised for errors in datasource parameters.""" pass -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" +class DatasourceFileError(DatasourceNodeError): + """Exception raised for errors related to datasource files.""" pass diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 673d0ba049..7edc73b6ba 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -7,6 +7,7 @@ class NodeType(StrEnum): ANSWER = "answer" LLM = "llm" KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + KNOWLEDGE_INDEX = "knowledge-index" IF_ELSE = "if-else" CODE = "code" TEMPLATE_TRANSFORM = "template-transform" diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..01d59b87b2 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_index_node import KnowledgeIndexNode + +__all__ = ["KnowledgeIndexNode"] diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py new file mode 100644 index 0000000000..a87032dba6 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -0,0 +1,147 @@ +from collections.abc import Sequence +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm.entities import VisionConfig + + +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + provider: str + model: str + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class RetrievalSetting(BaseModel): + """ + Retrieval Setting. + """ + search_method: Literal["semantic_search", "keyword_search", "hybrid_search"] + top_k: int + score_threshold: Optional[float] = 0.5 + score_threshold_enabled: bool = False + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: Optional[RerankingModelConfig] = None + weights: Optional[WeightedScoreConfig] = None + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting.
+ """ + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting + +class FileInfo(BaseModel): + """ + File Info. + """ + file_id: str + +class OnlineDocumentIcon(BaseModel): + """ + Document Icon. + """ + icon_url: str + icon_type: str + icon_emoji: str + +class OnlineDocumentInfo(BaseModel): + """ + Online document info. + """ + provider: str + workspace_id: str + page_id: str + page_type: str + icon: OnlineDocumentIcon + +class WebsiteInfo(BaseModel): + """ + website import info. + """ + provider: str + url: str + +class GeneralStructureChunk(BaseModel): + """ + General Structure Chunk. + """ + general_chunk: list[str] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class ParentChildChunk(BaseModel): + """ + Parent Child Chunk. + """ + parent_content: str + child_content: list[str] + + +class ParentChildStructureChunk(BaseModel): + """ + Parent Child Structure Chunk. + """ + parent_child_chunks: list[ParentChildChunk] + data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo] + + +class KnowledgeIndexNodeData(BaseNodeData): + """ + Knowledge index Node Data. + """ + + type: str = "knowledge-index" + dataset_id: str + index_chunk_variable_selector: list[str] + chunk_structure: Literal["general", "parent-child"] + index_method: IndexMethod + retrieval_setting: RetrievalSetting + diff --git a/api/core/workflow/nodes/knowledge_index/exc.py b/api/core/workflow/nodes/knowledge_index/exc.py new file mode 100644 index 0000000000..afdde9c0c5 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/exc.py @@ -0,0 +1,22 @@ +class KnowledgeIndexNodeError(ValueError): + """Base class for KnowledgeIndexNode errors.""" + + +class ModelNotExistError(KnowledgeIndexNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeIndexNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeIndexNodeError): + """Raised when the model provider quota is exceeded.""" + + +class InvalidModelTypeError(KnowledgeIndexNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py new file mode 100644 index 0000000000..543a170fa7 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -0,0 +1,154 @@ +import json +import logging +import re +import time +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast + +from sqlalchemy import Integer, and_, func, or_, text +from sqlalchemy import cast as sqlalchemy_cast + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import PromptMessageRole +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.simple_prompt_transform import ModelMode +from 
core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.metadata_entities import Condition, MetadataCondition +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.variables import StringSegment +from core.variables.segments import ObjectSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.knowledge_retrieval.template_prompts import ( + METADATA_FILTER_ASSISTANT_PROMPT_1, + METADATA_FILTER_ASSISTANT_PROMPT_2, + METADATA_FILTER_COMPLETION_PROMPT, + METADATA_FILTER_SYSTEM_PROMPT, + METADATA_FILTER_USER_PROMPT_1, + METADATA_FILTER_USER_PROMPT_3, +) +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2 +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog +from models.workflow import WorkflowNodeExecutionStatus +from services.dataset_service import DatasetService +from services.feature_service import FeatureService + +from .entities import KnowledgeIndexNodeData +from .exc import ( + InvalidModelTypeError, + KnowledgeIndexNodeError, + ModelCredentialsNotInitializedError, + ModelNotExistError, + ModelNotSupportedError, + ModelQuotaExceededError, +) + +logger = logging.getLogger(__name__) + +default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, +} + + +class KnowledgeIndexNode(LLMNode): + _node_data_cls = KnowledgeIndexNodeData # type: ignore + _node_type = NodeType.KNOWLEDGE_INDEX + + def _run(self) -> NodeRunResult: # type: ignore + node_data = cast(KnowledgeIndexNodeData, self.node_data) + # extract variables + variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector) + if not isinstance(variable, ObjectSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Chunks variable is not object type.", + ) + chunks = variable.value + variables = {"chunks": chunks} + if not chunks: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks are required."
+ ) + # check rate limit + if self.tenant_id: + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{self.tenant_id}" + redis_client.zadd(key, {current_time: current_time}) + redis_client.zremrangebyscore(key, 0, current_time - 60000) + request_count = redis_client.zcard(key) + if request_count > knowledge_rate_limit.limit: + # add ratelimit record + rate_limit_log = RateLimitLog( + tenant_id=self.tenant_id, + subscription_plan=knowledge_rate_limit.subscription_plan, + operation="knowledge", + ) + db.session.add(rate_limit_log) + db.session.commit() + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error="Sorry, you have reached the knowledge base request rate limit of your subscription.", + error_type="RateLimitExceeded", + ) + + # index knowledge + try: + results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks) + outputs = {"result": results} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs + ) + + except KnowledgeIndexNodeError as e: + logger.warning("Error when running knowledge index node") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + # Temporarily handle all exceptions from DatasetService here. + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + ) + + def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any: + dataset = Dataset.query.filter_by(id=node_data.dataset_id).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.") + + return DatasetService.invoke_knowledge_index( + dataset=dataset, + chunks=chunks, + index_method=node_data.index_method, + retrieval_setting=node_data.retrieval_setting, + ) diff --git a/api/core/workflow/nodes/knowledge_index/template_prompts.py b/api/core/workflow/nodes/knowledge_index/template_prompts.py new file mode 100644 index 0000000000..7abd55d798 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/template_prompts.py @@ -0,0 +1,66 @@ +METADATA_FILTER_SYSTEM_PROMPT = """ + ### Job Description + You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value + ### Task + Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". + ### Format + The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. + ### Constraint + DO NOT include anything other than the JSON array in your response.
+""" # noqa: E501 + +METADATA_FILTER_USER_PROMPT_1 = """ + { "input_text": "I want to know which company’s email address test@example.com is?", + "metadata_fields": ["filename", "email", "phone", "address"] + } +""" + +METADATA_FILTER_ASSISTANT_PROMPT_1 = """ +```json + {"metadata_map": [ + {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} + ] + } +``` +""" + +METADATA_FILTER_USER_PROMPT_2 = """ + {"input_text": "What are the movies with a score of more than 9 in 2024?", + "metadata_fields": ["name", "year", "rating", "country"]} +""" + +METADATA_FILTER_ASSISTANT_PROMPT_2 = """ +```json + {"metadata_map": [ + {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, + {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, + ]} +``` +""" + +METADATA_FILTER_USER_PROMPT_3 = """ + '{{"input_text": "{input_text}",', + '"metadata_fields": {metadata_fields}}}' +""" + +METADATA_FILTER_COMPLETION_PROMPT = """ +### Job Description +You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value +### Task +# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator". +### Format +The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. +### Constraint +DO NOT include anything other than the JSON array in your response. +### Example +Here is the chat example between human and assistant, inside XML tags. + +User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} +Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} +User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} +Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} + +### User Input +{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} +### Assistant Output +""" # noqa: E501 diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index d2e5a15545..17b3308a06 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -59,7 +59,6 @@ class MultipleRetrievalConfig(BaseModel): class ModelConfig(BaseModel): """ Model Config. 
- """ provider: str name: str diff --git a/api/models/dataset.py b/api/models/dataset.py index a344ab2964..3c44fb4b45 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -59,6 +59,7 @@ class Dataset(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) + keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10")) collection_binding_id = db.Column(StringUUID, nullable=True) retrieval_model = db.Column(JSONB, nullable=True) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index b019cf6b63..19962d66b9 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,6 +21,7 @@ from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -1131,6 +1132,408 @@ class DocumentService: return documents, batch @staticmethod + def save_document_with_dataset_id( + dataset: Dataset, + knowledge_config: KnowledgeConfig, + account: Account | Any, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", + ): + # check document limit + features = FeatureService.get_features(current_user.current_tenant_id) + + if features.billing.enabled: + if not knowledge_config.original_document_id: + count = 0 + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + + if features.billing.subscription.plan == "sandbox" and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) + + # if dataset is empty, update dataset data_source_type + if not dataset.data_source_type: + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore + + if not dataset.indexing_technique: + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = knowledge_config.indexing_technique + if 
knowledge_config.indexing_technique == "high_quality": + model_manager = ModelManager() + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore + + documents = [] + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + # save process rule + if not dataset_process_rule: + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + 
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore + notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore + if not notion_info_list: + raise ValueError("No notion info list found.") + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info.workspace_id + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info.pages: + if page.page_id not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, + } + # Truncate page name to 255 characters to prevent DB field length errors + truncated_page_name = page.page_name[:255] if page.page_name else "nopagename" + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + truncated_page_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page.page_id) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore + website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": 
website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." + else: + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() + + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + + @staticmethod + def invoke_knowledge_index( + dataset: Dataset, + chunks: list[Any], + index_method: IndexMethod, + retrieval_setting: RetrievalSetting, + account: Account | Any, + original_document_id: str | None = None, + created_from: str = "rag-pipeline", + ): + + if not dataset.indexing_technique: + if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") + + dataset.indexing_technique = index_method.indexing_technique + if index_method.indexing_technique == "high_quality": + model_manager = ModelManager() + if index_method.embedding_setting.embedding_model_name and index_method.embedding_setting.embedding_provider_name: + dataset_embedding_model = index_method.embedding_setting.embedding_model_name + dataset_embedding_model_provider = index_method.embedding_setting.embedding_provider_name + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + dataset_embedding_model_provider, dataset_embedding_model + ) + dataset.collection_binding_id = dataset_collection_binding.id + if not dataset.retrieval_model: + default_retrieval_model = { + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, + } + + dataset.retrieval_model = ( + retrieval_setting.model_dump() + if retrieval_setting + else default_retrieval_model + ) # type: ignore + + documents = [] + if original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + documents.append(document) + batch = document.batch + else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + for chunk in chunks: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + + # raise error if file not found + if
not file: + raise FileNotExistsError() + + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if knowledge_config.duplicate: + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.created_from = created_from + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, # type: ignore + knowledge_config.doc_form, + knowledge_config.doc_language, + data_source_info, + created_from, + position, + account, + file_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + + db.session.commit() + + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + + return documents, batch + @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: