jyong 2025-04-28 16:19:12 +08:00
parent d4007ae073
commit 49d1846e63
13 changed files with 902 additions and 109 deletions

View File

@@ -3,58 +3,58 @@ from collections.abc import Generator
 from mimetypes import guess_extension
 from typing import Optional
 
+from core.datasource.datasource_file_manager import DatasourceFileManager
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
 from core.file import File, FileTransferMethod, FileType
-from core.tools.entities.tool_entities import ToolInvokeMessage
-from core.tools.tool_file_manager import ToolFileManager
 
 logger = logging.getLogger(__name__)
 
 
-class ToolFileMessageTransformer:
+class DatasourceFileMessageTransformer:
     @classmethod
-    def transform_tool_invoke_messages(
+    def transform_datasource_invoke_messages(
         cls,
-        messages: Generator[ToolInvokeMessage, None, None],
+        messages: Generator[DatasourceInvokeMessage, None, None],
         user_id: str,
         tenant_id: str,
         conversation_id: Optional[str] = None,
-    ) -> Generator[ToolInvokeMessage, None, None]:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         """
-        Transform tool message and handle file download
+        Transform datasource message and handle file download
         """
         for message in messages:
-            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
+            if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
                 yield message
-            elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
-                message.message, ToolInvokeMessage.TextMessage
+            elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
+                message.message, DatasourceInvokeMessage.TextMessage
             ):
                 # try to download image
                 try:
-                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                    assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
 
-                    file = ToolFileManager.create_file_by_url(
+                    file = DatasourceFileManager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
                         file_url=message.message.text,
                         conversation_id=conversation_id,
                     )
 
-                    url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
+                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
 
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 except Exception as e:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.TEXT,
-                        message=ToolInvokeMessage.TextMessage(
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.TEXT,
+                        message=DatasourceInvokeMessage.TextMessage(
                             text=f"Failed to download image: {message.message.text}: {e}"
                         ),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
                 meta = message.meta or {}
@@ -63,12 +63,12 @@ class ToolFileMessageTransformer:
                 filename = meta.get("file_name", None)
                 # if message is str, encode it to bytes
 
-                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+                if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
                     raise ValueError("unexpected message type")
 
                 # FIXME: should do a type check here.
                 assert isinstance(message.message.blob, bytes)
-                file = ToolFileManager.create_file_by_raw(
+                file = DatasourceFileManager.create_file_by_raw(
                     user_id=user_id,
                     tenant_id=tenant_id,
                     conversation_id=conversation_id,
@@ -77,22 +77,22 @@ class ToolFileMessageTransformer:
                     filename=filename,
                 )
 
-                url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
+                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
 
                 # check if file is image
                 if "image" in mimetype:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
                 else:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.BINARY_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                 meta = message.meta or {}
                 file = meta.get("file", None)
                 if isinstance(file, File):
@@ -100,15 +100,15 @@ class ToolFileMessageTransformer:
                     assert file.related_id is not None
                     url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
                     if file.type == FileType.IMAGE:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
                     else:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
                 else:
@@ -117,5 +117,5 @@ class ToolFileMessageTransformer:
                     yield message
 
     @classmethod
-    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
-        return f"/files/tools/{tool_file_id}{extension or '.bin'}"
+    def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+        return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@@ -88,13 +88,14 @@ class PluginDatasourceManager(BasePluginManager):
         response = self._request_with_plugin_daemon_response_stream(
             "POST",
-            f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step",
+            f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
             ToolInvokeMessage,
             data={
                 "user_id": user_id,
                 "data": {
                     "provider": datasource_provider_id.provider_name,
                     "datasource": datasource_name,
                     "credentials": credentials,
                     "datasource_parameters": datasource_parameters,
                 },
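
For reference, the request body assembled for the new pages route has the shape below; only the structure comes from the data= dict above, the values are placeholders.

payload = {
    "user_id": "user-1",
    "data": {
        "provider": "notion",               # datasource_provider_id.provider_name
        "datasource": "notion-datasource",  # datasource_name
        "credentials": {"integration_secret": "..."},
        "datasource_parameters": {"workspace_id": "..."},
    },
}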

View File

@@ -5,13 +5,13 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.datasource.datasource_engine import DatasourceEngine
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
+from core.datasource.errors import DatasourceInvokeError
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.file import File, FileTransferMethod
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from core.tools.errors import ToolInvokeError
-from core.tools.tool_engine import ToolEngine
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -29,11 +29,7 @@ from models.workflow import WorkflowNodeExecutionStatus
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
 
 from .entities import DatasourceNodeData
-from .exc import (
-    ToolFileError,
-    ToolNodeError,
-    ToolParameterError,
-)
+from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
 
 
 class DatasourceNode(BaseNode[DatasourceNodeData]):
@@ -60,12 +56,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         # get datasource runtime
         try:
-            from core.tools.tool_manager import ToolManager
+            from core.datasource.datasource_manager import DatasourceManager
 
-            tool_runtime = ToolManager.get_workflow_tool_runtime(
+            datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
                 self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
             )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
@@ -78,14 +74,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             return
 
         # get parameters
-        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+        datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
         parameters = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
        )
        parameters_for_log = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
            for_log=True,
@@ -95,9 +91,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
 
         try:
-            message_stream = ToolEngine.generic_invoke(
-                tool=tool_runtime,
-                tool_parameters=parameters,
+            message_stream = DatasourceEngine.generic_invoke(
+                datasource=datasource_runtime,
+                datasource_parameters=parameters,
                 user_id=self.user_id,
                 workflow_tool_callback=DifyWorkflowCallbackHandler(),
                 workflow_call_depth=self.workflow_call_depth,
@@ -105,28 +101,28 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 app_id=self.app_id,
                 conversation_id=conversation_id.text if conversation_id else None,
             )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to invoke tool: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
                     error_type=type(e).__name__,
                 )
             )
             return
 
         try:
-            # convert tool messages
-            yield from self._transform_message(message_stream, tool_info, parameters_for_log)
-        except (PluginDaemonClientSideError, ToolInvokeError) as e:
+            # convert datasource messages
+            yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+        except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to transform tool message: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to transform datasource message: {str(e)}",
                     error_type=type(e).__name__,
                 )
             )
@@ -134,9 +130,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
     def _generate_parameters(
         self,
         *,
-        tool_parameters: Sequence[ToolParameter],
+        datasource_parameters: Sequence[DatasourceParameter],
         variable_pool: VariablePool,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
         for_log: bool = False,
     ) -> dict[str, Any]:
         """
@@ -151,25 +147,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             Mapping[str, Any]: A dictionary containing the generated parameters.
         """
-        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
 
         result: dict[str, Any] = {}
-        for parameter_name in node_data.tool_parameters:
-            parameter = tool_parameters_dictionary.get(parameter_name)
+        for parameter_name in node_data.datasource_parameters:
+            parameter = datasource_parameters_dictionary.get(parameter_name)
             if not parameter:
                 result[parameter_name] = None
                 continue
-            tool_input = node_data.tool_parameters[parameter_name]
-            if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
+            datasource_input = node_data.datasource_parameters[parameter_name]
+            if datasource_input.type == "variable":
+                variable = variable_pool.get(datasource_input.value)
                 if variable is None:
-                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+                    raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                 parameter_value = variable.value
-            elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
+            elif datasource_input.type in {"mixed", "constant"}:
+                segment_group = variable_pool.convert_template(str(datasource_input.value))
                 parameter_value = segment_group.log if for_log else segment_group.text
             else:
-                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+                raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
             result[parameter_name] = parameter_value
 
         return result
@@ -181,15 +177,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
     def _transform_message(
         self,
-        messages: Generator[ToolInvokeMessage, None, None],
-        tool_info: Mapping[str, Any],
+        messages: Generator[DatasourceInvokeMessage, None, None],
+        datasource_info: Mapping[str, Any],
         parameters_for_log: dict[str, Any],
     ) -> Generator:
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """
         # transform message and handle file storage
-        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
@@ -207,11 +203,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         for message in message_stream:
             if message.type in {
-                ToolInvokeMessage.MessageType.IMAGE_LINK,
-                ToolInvokeMessage.MessageType.BINARY_LINK,
-                ToolInvokeMessage.MessageType.IMAGE,
+                DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                DatasourceInvokeMessage.MessageType.IMAGE,
             }:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
 
                 url = message.message.text
                 if message.meta:
@@ -238,9 +234,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         tenant_id=self.tenant_id,
                     )
                 files.append(file)
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                 # get tool file id
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 assert message.meta
 
                 tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -261,14 +257,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         tenant_id=self.tenant_id,
                     )
                 )
-            elif message.type == ToolInvokeMessage.MessageType.TEXT:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 text += message.message.text
                 yield RunStreamChunkEvent(
                     chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                 )
-            elif message.type == ToolInvokeMessage.MessageType.JSON:
-                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.JSON:
+                assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
                 if self.node_type == NodeType.AGENT:
                     msg_metadata = message.message.json_object.pop("execution_metadata", {})
                     agent_execution_metadata = {
@@ -277,13 +273,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         if key in NodeRunMetadataKey.__members__.values()
                     }
                 json.append(message.message.json_object)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 stream_text = f"Link: {message.message.text}\n"
                 text += stream_text
                 yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
-            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
                 variable_name = message.message.variable_name
                 variable_value = message.message.variable_value
                 if message.message.stream:
@@ -298,13 +294,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     )
                 else:
                     variables[variable_name] = variable_value
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                 assert message.meta is not None
                 files.append(message.meta["file"])
-            elif message.type == ToolInvokeMessage.MessageType.LOG:
-                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LOG:
+                assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
                 if message.message.metadata:
-                    icon = tool_info.get("icon", "")
+                    icon = datasource_info.get("icon", "")
                     dict_metadata = dict(message.message.metadata)
                     if dict_metadata.get("provider"):
                         manager = PluginInstallationManager()
@@ -366,7 +362,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 outputs={"text": text, "files": files, "json": json, **variables},
                 metadata={
                     **agent_execution_metadata,
-                    NodeRunMetadataKey.TOOL_INFO: tool_info,
+                    NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
                     NodeRunMetadataKey.AGENT_LOG: agent_logs,
                 },
                 inputs=parameters_for_log,
@@ -379,7 +375,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
     ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
@@ -389,8 +385,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         :return:
         """
         result = {}
-        for parameter_name in node_data.tool_parameters:
-            input = node_data.tool_parameters[parameter_name]
+        for parameter_name in node_data.datasource_parameters:
+            input = node_data.datasource_parameters[parameter_name]
             if input.type == "mixed":
                 assert isinstance(input.value, str)
                 selectors = VariableTemplateParser(input.value).extract_variable_selectors()
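
A standalone sketch of the parameter-resolution rule used by _generate_parameters above, with a plain dict standing in for the real VariablePool; the "mixed"/"constant" branch is shown with a simplified template step.

from typing import Any


def resolve_parameter(input_type: str, value: Any, pool: dict[Any, Any]) -> Any:
    # "variable" inputs are looked up in the pool, everything else is rendered as text
    if input_type == "variable":
        if value not in pool:
            raise ValueError(f"Variable {value} does not exist")
        return pool[value]
    if input_type in {"mixed", "constant"}:
        return str(value)  # the real node renders the template against the pool here
    raise ValueError(f"Unknown datasource input type '{input_type}'")


pool = {("sys", "query"): "hello"}
print(resolve_parameter("variable", ("sys", "query"), pool))  # -> hello
print(resolve_parameter("constant", 42, pool))                # -> 42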

View File

@@ -1,16 +1,16 @@
-class ToolNodeError(ValueError):
-    """Base exception for tool node errors."""
+class DatasourceNodeError(ValueError):
+    """Base exception for datasource node errors."""
 
     pass
 
 
-class ToolParameterError(ToolNodeError):
-    """Exception raised for errors in tool parameters."""
+class DatasourceParameterError(DatasourceNodeError):
+    """Exception raised for errors in datasource parameters."""
 
     pass
 
 
-class ToolFileError(ToolNodeError):
-    """Exception raised for errors related to tool files."""
+class DatasourceFileError(DatasourceNodeError):
+    """Exception raised for errors related to datasource files."""
 
     pass
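
Since the hierarchy is preserved, callers can still catch the base class; a small illustration (the package path is assumed, the diff only shows the relative import from .exc):

from core.workflow.nodes.datasource.exc import DatasourceNodeError, DatasourceParameterError  # path assumed

try:
    raise DatasourceParameterError("Variable sys.query does not exist")
except DatasourceNodeError as e:
    print(f"datasource node failed: {e}")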

View File

@@ -7,6 +7,7 @@ class NodeType(StrEnum):
     ANSWER = "answer"
     LLM = "llm"
     KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+    KNOWLEDGE_INDEX = "knowledge-index"
     IF_ELSE = "if-else"
     CODE = "code"
     TEMPLATE_TRANSFORM = "template-transform"
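
Because NodeType is a StrEnum, the new value round-trips from workflow JSON directly; a quick check (module path assumed from the surrounding code):

from core.workflow.nodes.enums import NodeType

assert NodeType("knowledge-index") is NodeType.KNOWLEDGE_INDEX
assert NodeType.KNOWLEDGE_INDEX == "knowledge-index"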

View File

@@ -0,0 +1,3 @@
from .knowledge_index_node import KnowledgeIndexNode

__all__ = ["KnowledgeIndexNode"]

View File

@@ -0,0 +1,147 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field

from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    provider: str
    model: str


class VectorSetting(BaseModel):
    """
    Vector Setting.
    """

    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """
    Keyword Setting.
    """

    keyword_weight: float


class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting


class EmbeddingSetting(BaseModel):
    """
    Embedding Setting.
    """

    embedding_provider_name: str
    embedding_model_name: str


class EconomySetting(BaseModel):
    """
    Economy Setting.
    """

    keyword_number: int


class RetrievalSetting(BaseModel):
    """
    Retrieval Setting.
    """

    search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
    top_k: int
    score_threshold: Optional[float] = 0.5
    score_threshold_enabled: bool = False
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: Optional[RerankingModelConfig] = None
    weights: Optional[WeightedScoreConfig] = None


class IndexMethod(BaseModel):
    """
    Knowledge Index Setting.
    """

    indexing_technique: Literal["high_quality", "economy"]
    embedding_setting: EmbeddingSetting
    economy_setting: EconomySetting


class FileInfo(BaseModel):
    """
    File Info.
    """

    file_id: str


class OnlineDocumentIcon(BaseModel):
    """
    Document Icon.
    """

    icon_url: str
    icon_type: str
    icon_emoji: str


class OnlineDocumentInfo(BaseModel):
    """
    Online document info.
    """

    provider: str
    workspace_id: str
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon


class WebsiteInfo(BaseModel):
    """
    website import info.
    """

    provider: str
    url: str


class GeneralStructureChunk(BaseModel):
    """
    General Structure Chunk.
    """

    general_chunk: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
    """
    Parent Child Chunk.
    """

    parent_content: str
    child_content: list[str]


class ParentChildStructureChunk(BaseModel):
    """
    Parent Child Structure Chunk.
    """

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class KnowledgeIndexNodeData(BaseNodeData):
    """
    Knowledge index Node Data.
    """

    type: str = "knowledge-index"
    dataset_id: str
    index_chunk_variable_selector: list[str]
    chunk_structure: Literal["general", "parent-child"]
    index_method: IndexMethod
    retrieval_setting: RetrievalSetting
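
A small sketch validating the new settings models from plain dicts via pydantic v2's model_validate; the providers and model names are placeholders.

from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting

retrieval = RetrievalSetting.model_validate(
    {
        "search_method": "hybrid_search",
        "top_k": 5,
        "score_threshold": 0.6,
        "score_threshold_enabled": True,
        "reranking_model": {"provider": "cohere", "model": "rerank-english-v3.0"},
    }
)
index_method = IndexMethod.model_validate(
    {
        "indexing_technique": "high_quality",
        "embedding_setting": {
            "embedding_provider_name": "openai",
            "embedding_model_name": "text-embedding-3-small",
        },
        "economy_setting": {"keyword_number": 10},
    }
)
print(retrieval.search_method, index_method.indexing_technique)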

View File

@@ -0,0 +1,22 @@
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""

View File

@@ -0,0 +1,154 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.dataset_service import DatasetService
from services.feature_service import FeatureService
from .entities import KnowledgeIndexNodeData
from .exc import (
    InvalidModelTypeError,
    KnowledgeIndexNodeError,
    ModelCredentialsNotInitializedError,
    ModelNotExistError,
    ModelNotSupportedError,
    ModelQuotaExceededError,
)
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class KnowledgeIndexNode(LLMNode):
_node_data_cls = KnowledgeIndexNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_INDEX
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
if not isinstance(variable, ObjectSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not object type.",
)
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
try:
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)
except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any:
dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
DatasetService.invoke_knowledge_index(
dataset=dataset,
chunks=chunks,
index_method=node_data.index_method,
retrieval_setting=node_data.retrieval_setting,
)
pass
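
The rate-limit check in _run is a one-minute sliding window on a Redis sorted set; a condensed, self-contained version of that pattern, with a plain redis-py client standing in for the project's redis_client:

import time

import redis


def over_knowledge_rate_limit(r: redis.Redis, tenant_id: str, limit: int) -> bool:
    # record this request, drop entries older than 60s, then compare the window size
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    r.zadd(key, {now_ms: now_ms})
    r.zremrangebyscore(key, 0, now_ms - 60_000)
    return r.zcard(key) > limit


# usage sketch (assumes a local Redis):
# if over_knowledge_rate_limit(redis.Redis(), "tenant-1", limit=100):
#     raise RuntimeError("knowledge base request rate limit reached")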

View File

@@ -0,0 +1,66 @@
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501
METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which companys email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""
METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""
METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""
METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""
METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""
METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which companys email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

View File

@@ -59,7 +59,6 @@ class MultipleRetrievalConfig(BaseModel):
 class ModelConfig(BaseModel):
     """
     Model Config.
-    """
 
     provider: str
     name: str

View File

@@ -59,6 +59,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
+    keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
     collection_binding_id = db.Column(StringUUID, nullable=True)
     retrieval_model = db.Column(JSONB, nullable=True)
     built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
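
The new column implies a schema migration that is not part of this diff; a hedged Alembic sketch of what the upgrade could look like, assuming the table is named datasets (revision boilerplate omitted):

import sqlalchemy as sa
from alembic import op


def upgrade():
    op.add_column(
        "datasets",
        sa.Column("keyword_number", sa.Integer(), nullable=True, server_default=sa.text("10")),
    )


def downgrade():
    op.drop_column("datasets", "keyword_number")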

View File

@@ -21,6 +21,7 @@ from core.plugin.entities.plugin import ModelProviderID
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
@@ -1131,6 +1132,408 @@ class DocumentService:
return documents, batch
@staticmethod
def save_document_with_dataset_id(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
# check document limit
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if not knowledge_config.original_document_id:
count = 0
if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list: # type: ignore
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
DocumentService.check_documents_upload_quota(count, features)
# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore
documents = []
if knowledge_config.original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
# save process rule
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
else:
logging.warning(
    f"Invalid process rule mode: {process_rule.mode}, cannot find dataset process rule"
)
return
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
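
The billing check at the top of save_document_with_dataset_id boils down to counting incoming documents per data source type; a compact sketch of that rule, with a plain dict standing in for knowledge_config.data_source.info_list:

def count_incoming_documents(info_list: dict) -> int:
    # mirrors the per-source counting used for the batch upload and quota checks above
    source_type = info_list.get("data_source_type")
    if source_type == "upload_file":
        return len(info_list.get("file_ids", []))
    if source_type == "notion_import":
        return sum(len(ws.get("pages", [])) for ws in info_list.get("notion_info_list", []))
    if source_type == "website_crawl":
        return len(info_list.get("urls", []))
    return 0


print(count_incoming_documents({"data_source_type": "upload_file", "file_ids": ["f1", "f2"]}))  # -> 2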
@staticmethod
def invoke_knowledge_index(
dataset: Dataset,
chunks: list[Any],
index_method: IndexMethod,
retrieval_setting: RetrievalSetting,
account: Account | Any,
original_document_id: str | None = None,
created_from: str = "rag-pipeline",
):
if not dataset.indexing_technique:
if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")
dataset.indexing_technique = index_method.indexing_technique
if index_method.indexing_technique == "high_quality":
model_manager = ModelManager()
if index_method.embedding_setting.embedding_model and index_method.embedding_setting.embedding_model_provider:
dataset_embedding_model = index_method.embedding_setting.embedding_model
dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
dataset.retrieval_model = (
retrieval_setting.model_dump()
if retrieval_setting
else default_retrieval_model
) # type: ignore
documents = []
if original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
for chunk in chunks:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
# raise error if file not found
if not file:
raise FileNotExistsError()
file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size: