From 49d1846e63ebeeea5d0d77b41d53b54592c7947e Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 28 Apr 2025 16:19:12 +0800
Subject: [PATCH] r2
---
.../datasource/utils/message_transformer.py | 76 ++--
api/core/plugin/manager/datasource.py | 3 +-
.../nodes/datasource/datasource_node.py | 122 +++---
api/core/workflow/nodes/datasource/exc.py | 12 +-
api/core/workflow/nodes/enums.py | 1 +
.../nodes/knowledge_index/__init__.py | 3 +
.../nodes/knowledge_index/entities.py | 147 +++++++
.../workflow/nodes/knowledge_index/exc.py | 22 +
.../knowledge_index/knowledge_index_node.py | 154 +++++++
.../nodes/knowledge_index/template_prompts.py | 66 +++
.../nodes/knowledge_retrieval/entities.py | 1 -
api/models/dataset.py | 1 +
api/services/dataset_service.py | 403 ++++++++++++++++++
13 files changed, 902 insertions(+), 109 deletions(-)
create mode 100644 api/core/workflow/nodes/knowledge_index/__init__.py
create mode 100644 api/core/workflow/nodes/knowledge_index/entities.py
create mode 100644 api/core/workflow/nodes/knowledge_index/exc.py
create mode 100644 api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
create mode 100644 api/core/workflow/nodes/knowledge_index/template_prompts.py
diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py
index 6fd0c201e3..a10030d93b 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -3,58 +3,58 @@ from collections.abc import Generator
from mimetypes import guess_extension
from typing import Optional
+from core.datasource.datasource_file_manager import DatasourceFileManager
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
from core.file import File, FileTransferMethod, FileType
-from core.tools.entities.tool_entities import ToolInvokeMessage
-from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
-class ToolFileMessageTransformer:
+class DatasourceFileMessageTransformer:
@classmethod
- def transform_tool_invoke_messages(
+ def transform_datasource_invoke_messages(
cls,
- messages: Generator[ToolInvokeMessage, None, None],
+ messages: Generator[DatasourceInvokeMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
+ ) -> Generator[DatasourceInvokeMessage, None, None]:
"""
- Transform tool message and handle file download
+ Transform datasource message and handle file download
"""
for message in messages:
- if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
+ if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
yield message
- elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
- message.message, ToolInvokeMessage.TextMessage
+ elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
+ message.message, DatasourceInvokeMessage.TextMessage
):
# try to download image
try:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
- file = ToolFileManager.create_file_by_url(
+ file = DatasourceFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
- url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
+ url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.TEXT,
- message=ToolInvokeMessage.TextMessage(
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.TEXT,
+ message=DatasourceInvokeMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
- elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
@@ -63,12 +63,12 @@ class ToolFileMessageTransformer:
filename = meta.get("file_name", None)
# if message is str, encode it to bytes
- if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+ if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
- file = ToolFileManager.create_file_by_raw(
+ file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
@@ -77,22 +77,22 @@ class ToolFileMessageTransformer:
filename=filename,
)
- url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
+ url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.BINARY_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
- elif message.type == ToolInvokeMessage.MessageType.FILE:
+ elif message.type == DatasourceInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
@@ -100,15 +100,15 @@ class ToolFileMessageTransformer:
assert file.related_id is not None
-                    url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
+                    url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
@@ -117,5 +117,5 @@ class ToolFileMessageTransformer:
yield message
@classmethod
- def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
- return f"/files/tools/{tool_file_id}{extension or '.bin'}"
+ def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+ return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
diff --git a/api/core/plugin/manager/datasource.py b/api/core/plugin/manager/datasource.py
index 5a6f557e4b..efb42cd259 100644
--- a/api/core/plugin/manager/datasource.py
+++ b/api/core/plugin/manager/datasource.py
@@ -88,13 +88,14 @@ class PluginDatasourceManager(BasePluginManager):
response = self._request_with_plugin_daemon_response_stream(
"POST",
- f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step",
+ f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
ToolInvokeMessage,
data={
"user_id": user_id,
"data": {
"provider": datasource_provider_id.provider_name,
"datasource": datasource_name,
"credentials": credentials,
"datasource_parameters": datasource_parameters,
},
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 1752ba36fa..8ecf66c0d6 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -5,13 +5,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.datasource.datasource_engine import DatasourceEngine
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
+from core.datasource.errors import DatasourceInvokeError
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from core.tools.errors import ToolInvokeError
-from core.tools.tool_engine import ToolEngine
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -29,11 +29,7 @@ from models.workflow import WorkflowNodeExecutionStatus
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import DatasourceNodeData
-from .exc import (
- ToolFileError,
- ToolNodeError,
- ToolParameterError,
-)
+from .exc import DatasourceFileError, DatasourceNodeError, DatasourceParameterError
class DatasourceNode(BaseNode[DatasourceNodeData]):
@@ -60,12 +56,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
# get datasource runtime
try:
- from core.tools.tool_manager import ToolManager
+ from core.datasource.datasource_manager import DatasourceManager
- tool_runtime = ToolManager.get_workflow_tool_runtime(
+ datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
- except ToolNodeError as e:
+ except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -78,14 +74,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return
# get parameters
- tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+ datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
- tool_parameters=tool_parameters,
+ datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
- tool_parameters=tool_parameters,
+ datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
@@ -95,9 +91,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
- message_stream = ToolEngine.generic_invoke(
- tool=tool_runtime,
- tool_parameters=parameters,
+ message_stream = DatasourceEngine.generic_invoke(
+ datasource=datasource_runtime,
+ datasource_parameters=parameters,
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
@@ -105,28 +101,28 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
- except ToolNodeError as e:
+ except DatasourceNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
- error=f"Failed to invoke tool: {str(e)}",
+ metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+ error=f"Failed to invoke datasource: {str(e)}",
error_type=type(e).__name__,
)
)
return
try:
- # convert tool messages
- yield from self._transform_message(message_stream, tool_info, parameters_for_log)
- except (PluginDaemonClientSideError, ToolInvokeError) as e:
+ # convert datasource messages
+ yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+ except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
- error=f"Failed to transform tool message: {str(e)}",
+ metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+ error=f"Failed to transform datasource message: {str(e)}",
error_type=type(e).__name__,
)
)
@@ -134,9 +130,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
def _generate_parameters(
self,
*,
- tool_parameters: Sequence[ToolParameter],
+ datasource_parameters: Sequence[DatasourceParameter],
variable_pool: VariablePool,
- node_data: ToolNodeData,
+ node_data: DatasourceNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
@@ -151,25 +147,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
- tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+ datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {}
- for parameter_name in node_data.tool_parameters:
- parameter = tool_parameters_dictionary.get(parameter_name)
+ for parameter_name in node_data.datasource_parameters:
+ parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
- tool_input = node_data.tool_parameters[parameter_name]
- if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ datasource_input = node_data.datasource_parameters[parameter_name]
+ if datasource_input.type == "variable":
+ variable = variable_pool.get(datasource_input.value)
if variable is None:
- raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
- elif tool_input.type in {"mixed", "constant"}:
- segment_group = variable_pool.convert_template(str(tool_input.value))
+ elif datasource_input.type in {"mixed", "constant"}:
+ segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
- raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+ raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result
@@ -181,15 +177,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
def _transform_message(
self,
- messages: Generator[ToolInvokeMessage, None, None],
- tool_info: Mapping[str, Any],
+ messages: Generator[DatasourceInvokeMessage, None, None],
+ datasource_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
) -> Generator:
"""
-        Convert ToolInvokeMessages into tuple[plain_text, files]
+        Convert DatasourceInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
- message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+ message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
@@ -207,11 +203,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
for message in message_stream:
if message.type in {
- ToolInvokeMessage.MessageType.IMAGE_LINK,
- ToolInvokeMessage.MessageType.BINARY_LINK,
- ToolInvokeMessage.MessageType.IMAGE,
+ DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ DatasourceInvokeMessage.MessageType.BINARY_LINK,
+ DatasourceInvokeMessage.MessageType.IMAGE,
}:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
@@ -238,9 +234,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
tenant_id=self.tenant_id,
)
files.append(file)
- elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
# get tool file id
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -261,14 +257,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
tenant_id=self.tenant_id,
)
)
- elif message.type == ToolInvokeMessage.MessageType.TEXT:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
- elif message.type == ToolInvokeMessage.MessageType.JSON:
- assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.JSON:
+ assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
@@ -277,13 +273,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
if key in NodeRunMetadataKey.__members__.values()
}
json.append(message.message.json_object)
- elif message.type == ToolInvokeMessage.MessageType.LINK:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.LINK:
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
- elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
- assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+ assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
@@ -298,13 +294,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
)
else:
variables[variable_name] = variable_value
- elif message.type == ToolInvokeMessage.MessageType.FILE:
+ elif message.type == DatasourceInvokeMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
- elif message.type == ToolInvokeMessage.MessageType.LOG:
- assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.LOG:
+ assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
if message.message.metadata:
- icon = tool_info.get("icon", "")
+ icon = datasource_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstallationManager()
@@ -366,7 +362,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
outputs={"text": text, "files": files, "json": json, **variables},
metadata={
**agent_execution_metadata,
- NodeRunMetadataKey.TOOL_INFO: tool_info,
+ NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
NodeRunMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
@@ -379,7 +375,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ToolNodeData,
+ node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -389,8 +385,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
:return:
"""
result = {}
- for parameter_name in node_data.tool_parameters:
- input = node_data.tool_parameters[parameter_name]
+ for parameter_name in node_data.datasource_parameters:
+ input = node_data.datasource_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
diff --git a/api/core/workflow/nodes/datasource/exc.py b/api/core/workflow/nodes/datasource/exc.py
index 7212e8bfc0..89980e6f45 100644
--- a/api/core/workflow/nodes/datasource/exc.py
+++ b/api/core/workflow/nodes/datasource/exc.py
@@ -1,16 +1,16 @@
-class ToolNodeError(ValueError):
- """Base exception for tool node errors."""
+class DatasourceNodeError(ValueError):
+ """Base exception for datasource node errors."""
pass
-class ToolParameterError(ToolNodeError):
- """Exception raised for errors in tool parameters."""
+class DatasourceParameterError(DatasourceNodeError):
+ """Exception raised for errors in datasource parameters."""
pass
-class ToolFileError(ToolNodeError):
- """Exception raised for errors related to tool files."""
+class DatasourceFileError(DatasourceNodeError):
+ """Exception raised for errors related to datasource files."""
pass
diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py
index 673d0ba049..7edc73b6ba 100644
--- a/api/core/workflow/nodes/enums.py
+++ b/api/core/workflow/nodes/enums.py
@@ -7,6 +7,7 @@ class NodeType(StrEnum):
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+ KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py
new file mode 100644
index 0000000000..01d59b87b2
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_index/__init__.py
@@ -0,0 +1,3 @@
+from .knowledge_index_node import KnowledgeIndexNode
+
+__all__ = ["KnowledgeIndexNode"]
diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py
new file mode 100644
index 0000000000..a87032dba6
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_index/entities.py
@@ -0,0 +1,147 @@
+from collections.abc import Sequence
+from typing import Any, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+from core.workflow.nodes.base import BaseNodeData
+from core.workflow.nodes.llm.entities import VisionConfig
+
+
+class RerankingModelConfig(BaseModel):
+ """
+ Reranking Model Config.
+ """
+
+ provider: str
+ model: str
+
+class VectorSetting(BaseModel):
+ """
+ Vector Setting.
+ """
+
+ vector_weight: float
+ embedding_provider_name: str
+ embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+ """
+ Keyword Setting.
+ """
+
+ keyword_weight: float
+
+class WeightedScoreConfig(BaseModel):
+ """
+ Weighted score Config.
+ """
+
+ vector_setting: VectorSetting
+ keyword_setting: KeywordSetting
+
+
+class EmbeddingSetting(BaseModel):
+ """
+ Embedding Setting.
+ """
+ embedding_provider_name: str
+ embedding_model_name: str
+
+
+class EconomySetting(BaseModel):
+ """
+ Economy Setting.
+ """
+
+ keyword_number: int
+
+
+class RetrievalSetting(BaseModel):
+ """
+ Retrieval Setting.
+ """
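+    # serialized with model_dump() and stored on Dataset.retrieval_model by DatasetService.invoke_knowledge_index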
+ search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
+ top_k: int
+ score_threshold: Optional[float] = 0.5
+ score_threshold_enabled: bool = False
+ reranking_mode: str = "reranking_model"
+ reranking_enable: bool = True
+ reranking_model: Optional[RerankingModelConfig] = None
+ weights: Optional[WeightedScoreConfig] = None
+
+class IndexMethod(BaseModel):
+ """
+ Knowledge Index Setting.
+ """
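+    # embedding_setting is used for "high_quality" indexing; economy_setting (keyword_number) is intended for "economy" indexing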
+ indexing_technique: Literal["high_quality", "economy"]
+ embedding_setting: EmbeddingSetting
+ economy_setting: EconomySetting
+
+class FileInfo(BaseModel):
+ """
+ File Info.
+ """
+ file_id: str
+
+class OnlineDocumentIcon(BaseModel):
+ """
+ Document Icon.
+ """
+ icon_url: str
+ icon_type: str
+ icon_emoji: str
+
+class OnlineDocumentInfo(BaseModel):
+ """
+ Online document info.
+ """
+ provider: str
+ workspace_id: str
+ page_id: str
+ page_type: str
+ icon: OnlineDocumentIcon
+
+class WebsiteInfo(BaseModel):
+ """
+ website import info.
+ """
+ provider: str
+ url: str
+
+class GeneralStructureChunk(BaseModel):
+ """
+ General Structure Chunk.
+ """
+ general_chunk: list[str]
+ data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
+
+
+class ParentChildChunk(BaseModel):
+ """
+ Parent Child Chunk.
+ """
+ parent_content: str
+ child_content: list[str]
+
+
+class ParentChildStructureChunk(BaseModel):
+ """
+ Parent Child Structure Chunk.
+ """
+ parent_child_chunks: list[ParentChildChunk]
+ data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
+
+
+class KnowledgeIndexNodeData(BaseNodeData):
+ """
+ Knowledge index Node Data.
+ """
+
+ type: str = "knowledge-index"
+ dataset_id: str
+ index_chunk_variable_selector: list[str]
+ chunk_structure: Literal["general", "parent-child"]
+ index_method: IndexMethod
+ retrieval_setting: RetrievalSetting
+
diff --git a/api/core/workflow/nodes/knowledge_index/exc.py b/api/core/workflow/nodes/knowledge_index/exc.py
new file mode 100644
index 0000000000..afdde9c0c5
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_index/exc.py
@@ -0,0 +1,22 @@
+class KnowledgeIndexNodeError(ValueError):
+ """Base class for KnowledgeIndexNode errors."""
+
+
+class ModelNotExistError(KnowledgeIndexNodeError):
+ """Raised when the model does not exist."""
+
+
+class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
+ """Raised when the model credentials are not initialized."""
+
+
+class ModelNotSupportedError(KnowledgeIndexNodeError):
+ """Raised when the model is not supported."""
+
+
+class ModelQuotaExceededError(KnowledgeIndexNodeError):
+ """Raised when the model provider quota is exceeded."""
+
+
+class InvalidModelTypeError(KnowledgeIndexNodeError):
+ """Raised when the model is not a Large Language Model."""
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
new file mode 100644
index 0000000000..543a170fa7
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -0,0 +1,154 @@
+import json
+import logging
+import re
+import time
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
+
+from sqlalchemy import Integer, and_, func, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
+
+from core.app.app_config.entities import DatasetRetrieveConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.simple_prompt_transform import ModelMode
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.variables import StringSegment
+from core.variables.segments import ObjectSegment
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.knowledge_retrieval.template_prompts import (
+ METADATA_FILTER_ASSISTANT_PROMPT_1,
+ METADATA_FILTER_ASSISTANT_PROMPT_2,
+ METADATA_FILTER_COMPLETION_PROMPT,
+ METADATA_FILTER_SYSTEM_PROMPT,
+ METADATA_FILTER_USER_PROMPT_1,
+ METADATA_FILTER_USER_PROMPT_3,
+)
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
+from models.workflow import WorkflowNodeExecutionStatus
+from services.dataset_service import DatasetService
+from services.feature_service import FeatureService
+
+from .entities import KnowledgeIndexNodeData
+from .exc import (
+    InvalidModelTypeError,
+    KnowledgeIndexNodeError,
+    ModelCredentialsNotInitializedError,
+    ModelNotExistError,
+    ModelNotSupportedError,
+    ModelQuotaExceededError,
+)
+
+logger = logging.getLogger(__name__)
+
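+# fallback retrieval configuration, mirroring the default applied in DatasetService when a dataset has no retrieval_model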
+default_retrieval_model = {
+ "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+ "reranking_enable": False,
+ "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+ "top_k": 2,
+ "score_threshold_enabled": False,
+}
+
+
+class KnowledgeIndexNode(LLMNode):
+ _node_data_cls = KnowledgeIndexNodeData # type: ignore
+ _node_type = NodeType.KNOWLEDGE_INDEX
+
+ def _run(self) -> NodeRunResult: # type: ignore
+ node_data = cast(KnowledgeIndexNodeData, self.node_data)
+ # extract variables
+ variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
+ if not isinstance(variable, ObjectSegment):
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs={},
+                error="Chunks variable is not object type.",
+ )
+ chunks = variable.value
+ variables = {"chunks": chunks}
+ if not chunks:
+ return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks are required."
+ )
+ # check rate limit
+ if self.tenant_id:
+ knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
+ if knowledge_rate_limit.enabled:
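+                # sliding-window rate limit: record this request in a Redis sorted set, evict entries older than 60s, then count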
+ current_time = int(time.time() * 1000)
+ key = f"rate_limit_{self.tenant_id}"
+ redis_client.zadd(key, {current_time: current_time})
+ redis_client.zremrangebyscore(key, 0, current_time - 60000)
+ request_count = redis_client.zcard(key)
+ if request_count > knowledge_rate_limit.limit:
+ # add ratelimit record
+ rate_limit_log = RateLimitLog(
+ tenant_id=self.tenant_id,
+ subscription_plan=knowledge_rate_limit.subscription_plan,
+ operation="knowledge",
+ )
+ db.session.add(rate_limit_log)
+ db.session.commit()
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=variables,
+ error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
+ error_type="RateLimitExceeded",
+ )
+
+ # retrieve knowledge
+ try:
+ results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
+ outputs = {"result": results}
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
+ )
+
+ except KnowledgeIndexNodeError as e:
+ logger.warning("Error when running knowledge index node")
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=variables,
+ error=str(e),
+ error_type=type(e).__name__,
+ )
+ # Temporary handle all exceptions from DatasetRetrieval class here.
+ except Exception as e:
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=variables,
+ error=str(e),
+ error_type=type(e).__name__,
+ )
+
+
+    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any:
+        dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
+        if not dataset:
+            raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
+
+        # delegate chunk indexing to the dataset service and surface its result as the node output
+        return DatasetService.invoke_knowledge_index(
+            dataset=dataset,
+            chunks=chunks,
+            index_method=node_data.index_method,
+            retrieval_setting=node_data.retrieval_setting,
+        )
diff --git a/api/core/workflow/nodes/knowledge_index/template_prompts.py b/api/core/workflow/nodes/knowledge_index/template_prompts.py
new file mode 100644
index 0000000000..7abd55d798
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_index/template_prompts.py
@@ -0,0 +1,66 @@
+METADATA_FILTER_SYSTEM_PROMPT = """
+    ### Job Description
+    You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata values.
+    ### Task
+    Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value", and the comparison operator "comparison_operator".
+ ### Format
+ The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+ ### Constraint
+ DO NOT include anything other than the JSON array in your response.
+""" # noqa: E501
+
+METADATA_FILTER_USER_PROMPT_1 = """
+ { "input_text": "I want to know which company’s email address test@example.com is?",
+ "metadata_fields": ["filename", "email", "phone", "address"]
+ }
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_1 = """
+```json
+ {"metadata_map": [
+ {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
+ ]
+ }
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_2 = """
+ {"input_text": "What are the movies with a score of more than 9 in 2024?",
+ "metadata_fields": ["name", "year", "rating", "country"]}
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_2 = """
+```json
+ {"metadata_map": [
+ {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
+ {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
+ ]}
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_3 = """
+ '{{"input_text": "{input_text}",',
+ '"metadata_fields": {metadata_fields}}}'
+"""
+
+METADATA_FILTER_COMPLETION_PROMPT = """
+### Job Description
+You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata values.
+### Task
+Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value", and the comparison operator "comparison_operator".
+### Format
+The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+### Constraint
+DO NOT include anything other than the JSON array in your response.
+### Example
+Here is the chat example between human and assistant, inside XML tags.
+
+User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
+User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
+
+### User Input
+{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
+### Assistant Output
+""" # noqa: E501
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index d2e5a15545..17b3308a06 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -59,7 +59,6 @@ class MultipleRetrievalConfig(BaseModel):
class ModelConfig(BaseModel):
"""
Model Config.
    """
-
provider: str
name: str
diff --git a/api/models/dataset.py b/api/models/dataset.py
index a344ab2964..3c44fb4b45 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -59,6 +59,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
+ keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index b019cf6b63..19962d66b9 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -21,6 +21,7 @@ from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db
@@ -1131,6 +1132,408 @@ class DocumentService:
return documents, batch
@staticmethod
+ def save_document_with_dataset_id(
+ dataset: Dataset,
+ knowledge_config: KnowledgeConfig,
+ account: Account | Any,
+ dataset_process_rule: Optional[DatasetProcessRule] = None,
+ created_from: str = "web",
+ ):
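+        """
+        Create documents for a dataset from uploaded files, Notion pages, or a website crawl,
+        apply the process rule and index settings, and trigger the async indexing tasks.
+        """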
+ # check document limit
+ features = FeatureService.get_features(current_user.current_tenant_id)
+
+ if features.billing.enabled:
+ if not knowledge_config.original_document_id:
+ count = 0
+ if knowledge_config.data_source:
+ if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+ upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
+ count = len(upload_file_list)
+ elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+ notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+ for notion_info in notion_info_list: # type: ignore
+ count = count + len(notion_info.pages)
+ elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+ website_info = knowledge_config.data_source.info_list.website_info_list
+ count = len(website_info.urls) # type: ignore
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+
+ if features.billing.subscription.plan == "sandbox" and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+
+ DocumentService.check_documents_upload_quota(count, features)
+
+ # if dataset is empty, update dataset data_source_type
+ if not dataset.data_source_type:
+ dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
+
+ if not dataset.indexing_technique:
+ if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+ raise ValueError("Indexing technique is invalid")
+
+ dataset.indexing_technique = knowledge_config.indexing_technique
+ if knowledge_config.indexing_technique == "high_quality":
+ model_manager = ModelManager()
+ if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
+ dataset_embedding_model = knowledge_config.embedding_model
+ dataset_embedding_model_provider = knowledge_config.embedding_model_provider
+ else:
+ embedding_model = model_manager.get_default_model_instance(
+ tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+ )
+ dataset_embedding_model = embedding_model.model
+ dataset_embedding_model_provider = embedding_model.provider
+ dataset.embedding_model = dataset_embedding_model
+ dataset.embedding_model_provider = dataset_embedding_model_provider
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ dataset_embedding_model_provider, dataset_embedding_model
+ )
+ dataset.collection_binding_id = dataset_collection_binding.id
+ if not dataset.retrieval_model:
+ default_retrieval_model = {
+ "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+ "reranking_enable": False,
+ "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+ "top_k": 2,
+ "score_threshold_enabled": False,
+ }
+
+ dataset.retrieval_model = (
+ knowledge_config.retrieval_model.model_dump()
+ if knowledge_config.retrieval_model
+ else default_retrieval_model
+ ) # type: ignore
+
+ documents = []
+ if knowledge_config.original_document_id:
+ document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
+ documents.append(document)
+ batch = document.batch
+ else:
+ batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
+ # save process rule
+ if not dataset_process_rule:
+ process_rule = knowledge_config.process_rule
+ if process_rule:
+ if process_rule.mode in ("custom", "hierarchical"):
+ dataset_process_rule = DatasetProcessRule(
+ dataset_id=dataset.id,
+ mode=process_rule.mode,
+ rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+ created_by=account.id,
+ )
+ elif process_rule.mode == "automatic":
+ dataset_process_rule = DatasetProcessRule(
+ dataset_id=dataset.id,
+ mode=process_rule.mode,
+ rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+ created_by=account.id,
+ )
+ else:
+                        logging.warning(
+                            f"Invalid process rule mode: {process_rule.mode}, cannot find dataset process rule"
+                        )
+ return
+ db.session.add(dataset_process_rule)
+ db.session.commit()
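+        # create documents under a per-dataset lock so position numbers stay sequential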
+ lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
+ with redis_client.lock(lock_name, timeout=600):
+ position = DocumentService.get_documents_position(dataset.id)
+ document_ids = []
+ duplicate_document_ids = []
+ if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
+ upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
+ for file_id in upload_file_list:
+ file = (
+ db.session.query(UploadFile)
+ .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
+ .first()
+ )
+
+ # raise error if file not found
+ if not file:
+ raise FileNotExistsError()
+
+ file_name = file.name
+ data_source_info = {
+ "upload_file_id": file_id,
+ }
+ # check duplicate
+ if knowledge_config.duplicate:
+ document = Document.query.filter_by(
+ dataset_id=dataset.id,
+ tenant_id=current_user.current_tenant_id,
+ data_source_type="upload_file",
+ enabled=True,
+ name=file_name,
+ ).first()
+ if document:
+ document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
+ document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.created_from = created_from
+ document.doc_form = knowledge_config.doc_form
+ document.doc_language = knowledge_config.doc_language
+ document.data_source_info = json.dumps(data_source_info)
+ document.batch = batch
+ document.indexing_status = "waiting"
+ db.session.add(document)
+ documents.append(document)
+ duplicate_document_ids.append(document.id)
+ continue
+ document = DocumentService.build_document(
+ dataset,
+ dataset_process_rule.id, # type: ignore
+ knowledge_config.data_source.info_list.data_source_type, # type: ignore
+ knowledge_config.doc_form,
+ knowledge_config.doc_language,
+ data_source_info,
+ created_from,
+ position,
+ account,
+ file_name,
+ batch,
+ )
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
+ elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
+ notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
+ if not notion_info_list:
+ raise ValueError("No notion info list found.")
+ exist_page_ids = []
+ exist_document = {}
+ documents = Document.query.filter_by(
+ dataset_id=dataset.id,
+ tenant_id=current_user.current_tenant_id,
+ data_source_type="notion_import",
+ enabled=True,
+ ).all()
+ if documents:
+ for document in documents:
+ data_source_info = json.loads(document.data_source_info)
+ exist_page_ids.append(data_source_info["notion_page_id"])
+ exist_document[data_source_info["notion_page_id"]] = document.id
+ for notion_info in notion_info_list:
+ workspace_id = notion_info.workspace_id
+ data_source_binding = DataSourceOauthBinding.query.filter(
+ db.and_(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.disabled == False,
+ DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+ )
+ ).first()
+ if not data_source_binding:
+ raise ValueError("Data source binding not found.")
+ for page in notion_info.pages:
+ if page.page_id not in exist_page_ids:
+ data_source_info = {
+ "notion_workspace_id": workspace_id,
+ "notion_page_id": page.page_id,
+ "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+ "type": page.type,
+ }
+ # Truncate page name to 255 characters to prevent DB field length errors
+ truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
+ document = DocumentService.build_document(
+ dataset,
+ dataset_process_rule.id, # type: ignore
+ knowledge_config.data_source.info_list.data_source_type, # type: ignore
+ knowledge_config.doc_form,
+ knowledge_config.doc_language,
+ data_source_info,
+ created_from,
+ position,
+ account,
+ truncated_page_name,
+ batch,
+ )
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
+ else:
+ exist_document.pop(page.page_id)
+ # delete not selected documents
+ if len(exist_document) > 0:
+ clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+ elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
+ website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
+ if not website_info:
+ raise ValueError("No website info list found.")
+ urls = website_info.urls
+ for url in urls:
+ data_source_info = {
+ "url": url,
+ "provider": website_info.provider,
+ "job_id": website_info.job_id,
+ "only_main_content": website_info.only_main_content,
+ "mode": "crawl",
+ }
+ if len(url) > 255:
+ document_name = url[:200] + "..."
+ else:
+ document_name = url
+ document = DocumentService.build_document(
+ dataset,
+ dataset_process_rule.id, # type: ignore
+ knowledge_config.data_source.info_list.data_source_type, # type: ignore
+ knowledge_config.doc_form,
+ knowledge_config.doc_language,
+ data_source_info,
+ created_from,
+ position,
+ account,
+ document_name,
+ batch,
+ )
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
+ db.session.commit()
+
+ # trigger async task
+ if document_ids:
+ document_indexing_task.delay(dataset.id, document_ids)
+ if duplicate_document_ids:
+ duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
+
+ return documents, batch
+
+ @staticmethod
+ def invoke_knowledge_index(
+ dataset: Dataset,
+ chunks: list[Any],
+ index_method: IndexMethod,
+ retrieval_setting: RetrievalSetting,
+ original_document_id: str | None = None,
+        account: Account | Any = None,
+        created_from: str = "rag-pipeline",
+ ):
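+        """
+        Configure the dataset's index and retrieval settings from index_method and retrieval_setting,
+        then create and asynchronously index documents for the provided chunks. The acting account is
+        used when new documents are built.
+        """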
+
+ if not dataset.indexing_technique:
+ if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
+ raise ValueError("Indexing technique is invalid")
+
+ dataset.indexing_technique = index_method.indexing_technique
+ if index_method.indexing_technique == "high_quality":
+ model_manager = ModelManager()
+            embedding_setting = index_method.embedding_setting
+            if embedding_setting.embedding_model_name and embedding_setting.embedding_provider_name:
+                dataset_embedding_model = embedding_setting.embedding_model_name
+                dataset_embedding_model_provider = embedding_setting.embedding_provider_name
+ else:
+ embedding_model = model_manager.get_default_model_instance(
+ tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
+ )
+ dataset_embedding_model = embedding_model.model
+ dataset_embedding_model_provider = embedding_model.provider
+ dataset.embedding_model = dataset_embedding_model
+ dataset.embedding_model_provider = dataset_embedding_model_provider
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ dataset_embedding_model_provider, dataset_embedding_model
+ )
+ dataset.collection_binding_id = dataset_collection_binding.id
+ if not dataset.retrieval_model:
+ default_retrieval_model = {
+ "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+ "reranking_enable": False,
+ "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+ "top_k": 2,
+ "score_threshold_enabled": False,
+ }
+
+ dataset.retrieval_model = (
+ retrieval_setting.model_dump()
+ if retrieval_setting
+ else default_retrieval_model
+ ) # type: ignore
+
+ documents = []
+ if original_document_id:
+ document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
+ documents.append(document)
+ batch = document.batch
+ else:
+ batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
+
+ lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
+ with redis_client.lock(lock_name, timeout=600):
+ position = DocumentService.get_documents_position(dataset.id)
+ document_ids = []
+ duplicate_document_ids = []
+ for chunk in chunks:
+ file = (
+ db.session.query(UploadFile)
+ .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
+ .first()
+ )
+
+ # raise error if file not found
+ if not file:
+ raise FileNotExistsError()
+
+ file_name = file.name
+ data_source_info = {
+ "upload_file_id": file_id,
+ }
+ # check duplicate
+ if knowledge_config.duplicate:
+ document = Document.query.filter_by(
+ dataset_id=dataset.id,
+ tenant_id=current_user.current_tenant_id,
+ data_source_type="upload_file",
+ enabled=True,
+ name=file_name,
+ ).first()
+ if document:
+ document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
+ document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.created_from = created_from
+ document.doc_form = knowledge_config.doc_form
+ document.doc_language = knowledge_config.doc_language
+ document.data_source_info = json.dumps(data_source_info)
+ document.batch = batch
+ document.indexing_status = "waiting"
+ db.session.add(document)
+ documents.append(document)
+ duplicate_document_ids.append(document.id)
+ continue
+ document = DocumentService.build_document(
+ dataset,
+ dataset_process_rule.id, # type: ignore
+ knowledge_config.data_source.info_list.data_source_type, # type: ignore
+ knowledge_config.doc_form,
+ knowledge_config.doc_language,
+ data_source_info,
+ created_from,
+ position,
+ account,
+ file_name,
+ batch,
+ )
+ db.session.add(document)
+ db.session.flush()
+ document_ids.append(document.id)
+ documents.append(document)
+ position += 1
+
+ db.session.commit()
+
+ # trigger async task
+ if document_ids:
+ document_indexing_task.delay(dataset.id, document_ids)
+ if duplicate_document_ids:
+ duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
+
+ return documents, batch
+ @staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size: