Mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-16 11:05:53 +08:00)
Commit 49d1846e63 ("r2"), parent d4007ae073
@@ -3,58 +3,58 @@ from collections.abc import Generator
 from mimetypes import guess_extension
 from typing import Optional
 
+from core.datasource.datasource_file_manager import DatasourceFileManager
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
 from core.file import File, FileTransferMethod, FileType
-from core.tools.entities.tool_entities import ToolInvokeMessage
-from core.tools.tool_file_manager import ToolFileManager
 
 logger = logging.getLogger(__name__)
 
 
-class ToolFileMessageTransformer:
+class DatasourceFileMessageTransformer:
     @classmethod
-    def transform_tool_invoke_messages(
+    def transform_datasource_invoke_messages(
         cls,
-        messages: Generator[ToolInvokeMessage, None, None],
+        messages: Generator[DatasourceInvokeMessage, None, None],
         user_id: str,
         tenant_id: str,
         conversation_id: Optional[str] = None,
-    ) -> Generator[ToolInvokeMessage, None, None]:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
         """
-        Transform tool message and handle file download
+        Transform datasource message and handle file download
         """
         for message in messages:
-            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
+            if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
                 yield message
-            elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
-                message.message, ToolInvokeMessage.TextMessage
+            elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
+                message.message, DatasourceInvokeMessage.TextMessage
             ):
                 # try to download image
                 try:
-                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                    assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
 
-                    file = ToolFileManager.create_file_by_url(
+                    file = DatasourceFileManager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
                         file_url=message.message.text,
                         conversation_id=conversation_id,
                     )
 
-                    url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
+                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
 
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 except Exception as e:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.TEXT,
-                        message=ToolInvokeMessage.TextMessage(
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.TEXT,
+                        message=DatasourceInvokeMessage.TextMessage(
                             text=f"Failed to download image: {message.message.text}: {e}"
                         ),
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
                 meta = message.meta or {}
 
@@ -63,12 +63,12 @@ class ToolFileMessageTransformer:
                 filename = meta.get("file_name", None)
                 # if message is str, encode it to bytes
 
-                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+                if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
                     raise ValueError("unexpected message type")
 
                 # FIXME: should do a type check here.
                 assert isinstance(message.message.blob, bytes)
-                file = ToolFileManager.create_file_by_raw(
+                file = DatasourceFileManager.create_file_by_raw(
                     user_id=user_id,
                     tenant_id=tenant_id,
                     conversation_id=conversation_id,
@@ -77,22 +77,22 @@ class ToolFileMessageTransformer:
                     filename=filename,
                 )
 
-                url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
+                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
 
                 # check if file is image
                 if "image" in mimetype:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
                 else:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.BINARY_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                         meta=meta.copy() if meta is not None else {},
                     )
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                 meta = message.meta or {}
                 file = meta.get("file", None)
                 if isinstance(file, File):
@@ -100,15 +100,15 @@ class ToolFileMessageTransformer:
                     assert file.related_id is not None
                     url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
                     if file.type == FileType.IMAGE:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
                     else:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                             meta=meta.copy() if meta is not None else {},
                         )
             else:
@@ -117,5 +117,5 @@ class ToolFileMessageTransformer:
             yield message
 
     @classmethod
-    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
-        return f"/files/tools/{tool_file_id}{extension or '.bin'}"
+    def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+        return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
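
For orientation, a minimal driver sketch (not part of the commit) showing how the renamed transformer is consumed; the sample generator and the literal IDs below are hypothetical:

    from collections.abc import Generator

    from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
    from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer


    def sample_messages() -> Generator[DatasourceInvokeMessage, None, None]:
        # TEXT and LINK messages pass through the transformer unchanged
        yield DatasourceInvokeMessage(
            type=DatasourceInvokeMessage.MessageType.TEXT,
            message=DatasourceInvokeMessage.TextMessage(text="hello"),
            meta={},
        )


    # IMAGE and BLOB messages would instead be persisted and rewritten
    # to /files/datasources/... links by the branches shown in the diff.
    for message in DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
        messages=sample_messages(),
        user_id="user-1",
        tenant_id="tenant-1",
    ):
        print(message.type, message.message)
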
@@ -88,13 +88,14 @@ class PluginDatasourceManager(BasePluginManager):
 
         response = self._request_with_plugin_daemon_response_stream(
             "POST",
-            f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step",
+            f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
             ToolInvokeMessage,
             data={
                 "user_id": user_id,
                 "data": {
                     "provider": datasource_provider_id.provider_name,
                     "datasource": datasource_name,
+
                     "credentials": credentials,
                     "datasource_parameters": datasource_parameters,
                 },
@@ -5,13 +5,13 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.datasource.datasource_engine import DatasourceEngine
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
+from core.datasource.errors import DatasourceInvokeError
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.file import File, FileTransferMethod
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from core.tools.errors import ToolInvokeError
-from core.tools.tool_engine import ToolEngine
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -29,11 +29,7 @@ from models.workflow import WorkflowNodeExecutionStatus
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
 
 from .entities import DatasourceNodeData
-from .exc import (
-    ToolFileError,
-    ToolNodeError,
-    ToolParameterError,
-)
+from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError
 
 
 class DatasourceNode(BaseNode[DatasourceNodeData]):
@@ -60,12 +56,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
         # get datasource runtime
         try:
-            from core.tools.tool_manager import ToolManager
+            from core.datasource.datasource_manager import DatasourceManager
 
-            tool_runtime = ToolManager.get_workflow_tool_runtime(
+            datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
                 self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
             )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
@@ -78,14 +74,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
             return
 
         # get parameters
-        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+        datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
         parameters = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
             node_data=self.node_data,
         )
         parameters_for_log = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
             node_data=self.node_data,
             for_log=True,
@@ -95,9 +91,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
 
         try:
-            message_stream = ToolEngine.generic_invoke(
-                tool=tool_runtime,
-                tool_parameters=parameters,
+            message_stream = DatasourceEngine.generic_invoke(
+                datasource=datasource_runtime,
+                datasource_parameters=parameters,
                 user_id=self.user_id,
                 workflow_tool_callback=DifyWorkflowCallbackHandler(),
                 workflow_call_depth=self.workflow_call_depth,
@@ -105,28 +101,28 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 app_id=self.app_id,
                 conversation_id=conversation_id.text if conversation_id else None,
             )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to invoke tool: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
                     error_type=type(e).__name__,
                 )
             )
             return
 
         try:
-            # convert tool messages
-            yield from self._transform_message(message_stream, tool_info, parameters_for_log)
-        except (PluginDaemonClientSideError, ToolInvokeError) as e:
+            # convert datasource messages
+            yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+        except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to transform tool message: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to transform datasource message: {str(e)}",
                     error_type=type(e).__name__,
                 )
             )
@@ -134,9 +130,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
     def _generate_parameters(
         self,
         *,
-        tool_parameters: Sequence[ToolParameter],
+        datasource_parameters: Sequence[DatasourceParameter],
         variable_pool: VariablePool,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
         for_log: bool = False,
     ) -> dict[str, Any]:
         """
@@ -151,25 +147,25 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
            Mapping[str, Any]: A dictionary containing the generated parameters.
 
        """
-        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
 
         result: dict[str, Any] = {}
-        for parameter_name in node_data.tool_parameters:
-            parameter = tool_parameters_dictionary.get(parameter_name)
+        for parameter_name in node_data.datasource_parameters:
+            parameter = datasource_parameters_dictionary.get(parameter_name)
             if not parameter:
                 result[parameter_name] = None
                 continue
-            tool_input = node_data.tool_parameters[parameter_name]
-            if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
+            datasource_input = node_data.datasource_parameters[parameter_name]
+            if datasource_input.type == "variable":
+                variable = variable_pool.get(datasource_input.value)
                 if variable is None:
-                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+                    raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                 parameter_value = variable.value
-            elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
+            elif datasource_input.type in {"mixed", "constant"}:
+                segment_group = variable_pool.convert_template(str(datasource_input.value))
                 parameter_value = segment_group.log if for_log else segment_group.text
             else:
-                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+                raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
             result[parameter_name] = parameter_value
 
         return result
@@ -181,15 +177,15 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
     def _transform_message(
         self,
-        messages: Generator[ToolInvokeMessage, None, None],
-        tool_info: Mapping[str, Any],
+        messages: Generator[DatasourceInvokeMessage, None, None],
+        datasource_info: Mapping[str, Any],
         parameters_for_log: dict[str, Any],
     ) -> Generator:
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """
         # transform message and handle file storage
-        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
@@ -207,11 +203,11 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
 
         for message in message_stream:
             if message.type in {
-                ToolInvokeMessage.MessageType.IMAGE_LINK,
-                ToolInvokeMessage.MessageType.BINARY_LINK,
-                ToolInvokeMessage.MessageType.IMAGE,
+                DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                DatasourceInvokeMessage.MessageType.IMAGE,
             }:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
 
                 url = message.message.text
                 if message.meta:
@@ -238,9 +234,9 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     tenant_id=self.tenant_id,
                 )
                 files.append(file)
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                 # get tool file id
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 assert message.meta
 
                 tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -261,14 +257,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         tenant_id=self.tenant_id,
                     )
                 )
-            elif message.type == ToolInvokeMessage.MessageType.TEXT:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 text += message.message.text
                 yield RunStreamChunkEvent(
                     chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                 )
-            elif message.type == ToolInvokeMessage.MessageType.JSON:
-                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.JSON:
+                assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
                 if self.node_type == NodeType.AGENT:
                     msg_metadata = message.message.json_object.pop("execution_metadata", {})
                     agent_execution_metadata = {
@@ -277,13 +273,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                         if key in NodeRunMetadataKey.__members__.values()
                     }
                 json.append(message.message.json_object)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                 stream_text = f"Link: {message.message.text}\n"
                 text += stream_text
                 yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
-            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
                 variable_name = message.message.variable_name
                 variable_value = message.message.variable_value
                 if message.message.stream:
@@ -298,13 +294,13 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     )
                 else:
                     variables[variable_name] = variable_value
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                 assert message.meta is not None
                 files.append(message.meta["file"])
-            elif message.type == ToolInvokeMessage.MessageType.LOG:
-                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LOG:
+                assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
                 if message.message.metadata:
-                    icon = tool_info.get("icon", "")
+                    icon = datasource_info.get("icon", "")
                     dict_metadata = dict(message.message.metadata)
                     if dict_metadata.get("provider"):
                         manager = PluginInstallationManager()
@@ -366,7 +362,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                 outputs={"text": text, "files": files, "json": json, **variables},
                 metadata={
                     **agent_execution_metadata,
-                    NodeRunMetadataKey.TOOL_INFO: tool_info,
+                    NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
                     NodeRunMetadataKey.AGENT_LOG: agent_logs,
                 },
                 inputs=parameters_for_log,
@@ -379,7 +375,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
     ) -> Mapping[str, Sequence[str]]:
         """
         Extract variable selector to variable mapping
@@ -389,8 +385,8 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         :return:
         """
         result = {}
-        for parameter_name in node_data.tool_parameters:
-            input = node_data.tool_parameters[parameter_name]
+        for parameter_name in node_data.datasource_parameters:
+            input = node_data.datasource_parameters[parameter_name]
             if input.type == "mixed":
                 assert isinstance(input.value, str)
                 selectors = VariableTemplateParser(input.value).extract_variable_selectors()
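
The `_generate_parameters` change above keeps the tool node's resolution rule, just renamed: "variable" inputs are looked up in the variable pool, "mixed"/"constant" inputs go through template conversion, and anything else raises. A standalone sketch of that rule (the dict-backed pool below is a hypothetical stand-in for the real VariablePool):

    from typing import Any


    def resolve_input(input_type: str, value: Any, pool: dict[str, Any]) -> Any:
        # mirrors the branch structure of DatasourceNode._generate_parameters
        if input_type == "variable":
            if value not in pool:
                raise ValueError(f"Variable {value} does not exist")
            return pool[value]
        elif input_type in {"mixed", "constant"}:
            # stands in for variable_pool.convert_template(str(value)).text
            return str(value)
        raise ValueError(f"Unknown datasource input type '{input_type}'")
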
@@ -1,16 +1,16 @@
-class ToolNodeError(ValueError):
-    """Base exception for tool node errors."""
+class DatasourceNodeError(ValueError):
+    """Base exception for datasource node errors."""
 
     pass
 
 
-class ToolParameterError(ToolNodeError):
-    """Exception raised for errors in tool parameters."""
+class DatasourceParameterError(DatasourceNodeError):
+    """Exception raised for errors in datasource parameters."""
 
     pass
 
 
-class ToolFileError(ToolNodeError):
-    """Exception raised for errors related to tool files."""
+class DatasourceFileError(DatasourceNodeError):
+    """Exception raised for errors related to datasource files."""
 
     pass
@@ -7,6 +7,7 @@ class NodeType(StrEnum):
     ANSWER = "answer"
     LLM = "llm"
     KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+    KNOWLEDGE_INDEX = "knowledge-index"
     IF_ELSE = "if-else"
     CODE = "code"
     TEMPLATE_TRANSFORM = "template-transform"
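
Since NodeType is a StrEnum, the new member parses from and compares equal to its string value; a quick sketch:

    from core.workflow.nodes.enums import NodeType

    assert NodeType("knowledge-index") is NodeType.KNOWLEDGE_INDEX
    assert NodeType.KNOWLEDGE_INDEX == "knowledge-index"  # StrEnum members equal their value
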
api/core/workflow/nodes/knowledge_index/__init__.py (new file, +3)
@@ -0,0 +1,3 @@
+from .knowledge_index_node import KnowledgeRetrievalNode
+
+__all__ = ["KnowledgeRetrievalNode"]
api/core/workflow/nodes/knowledge_index/entities.py (new file, +147)
@@ -0,0 +1,147 @@
+from collections.abc import Sequence
+from typing import Any, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+from core.workflow.nodes.base import BaseNodeData
+from core.workflow.nodes.llm.entities import VisionConfig
+
+
+class RerankingModelConfig(BaseModel):
+    """
+    Reranking Model Config.
+    """
+
+    provider: str
+    model: str
+
+class VectorSetting(BaseModel):
+    """
+    Vector Setting.
+    """
+
+    vector_weight: float
+    embedding_provider_name: str
+    embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+    """
+    Keyword Setting.
+    """
+
+    keyword_weight: float
+
+class WeightedScoreConfig(BaseModel):
+    """
+    Weighted score Config.
+    """
+
+    vector_setting: VectorSetting
+    keyword_setting: KeywordSetting
+
+
+class EmbeddingSetting(BaseModel):
+    """
+    Embedding Setting.
+    """
+    embedding_provider_name: str
+    embedding_model_name: str
+
+
+class EconomySetting(BaseModel):
+    """
+    Economy Setting.
+    """
+
+    keyword_number: int
+
+
+class RetrievalSetting(BaseModel):
+    """
+    Retrieval Setting.
+    """
+    search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
+    top_k: int
+    score_threshold: Optional[float] = 0.5
+    score_threshold_enabled: bool = False
+    reranking_mode: str = "reranking_model"
+    reranking_enable: bool = True
+    reranking_model: Optional[RerankingModelConfig] = None
+    weights: Optional[WeightedScoreConfig] = None
+
+class IndexMethod(BaseModel):
+    """
+    Knowledge Index Setting.
+    """
+    indexing_technique: Literal["high_quality", "economy"]
+    embedding_setting: EmbeddingSetting
+    economy_setting: EconomySetting
+
+class FileInfo(BaseModel):
+    """
+    File Info.
+    """
+    file_id: str
+
+class OnlineDocumentIcon(BaseModel):
+    """
+    Document Icon.
+    """
+    icon_url: str
+    icon_type: str
+    icon_emoji: str
+
+class OnlineDocumentInfo(BaseModel):
+    """
+    Online document info.
+    """
+    provider: str
+    workspace_id: str
+    page_id: str
+    page_type: str
+    icon: OnlineDocumentIcon
+
+class WebsiteInfo(BaseModel):
+    """
+    website import info.
+    """
+    provider: str
+    url: str
+
+class GeneralStructureChunk(BaseModel):
+    """
+    General Structure Chunk.
+    """
+    general_chunk: list[str]
+    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
+
+
+class ParentChildChunk(BaseModel):
+    """
+    Parent Child Chunk.
+    """
+    parent_content: str
+    child_content: list[str]
+
+
+class ParentChildStructureChunk(BaseModel):
+    """
+    Parent Child Structure Chunk.
+    """
+    parent_child_chunks: list[ParentChildChunk]
+    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
+
+
+class KnowledgeIndexNodeData(BaseNodeData):
+    """
+    Knowledge index Node Data.
+    """
+
+    type: str = "knowledge-index"
+    dataset_id: str
+    index_chunk_variable_selector: list[str]
+    chunk_structure: Literal["general", "parent-child"]
+    index_method: IndexMethod
+    retrieval_setting: RetrievalSetting
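
A minimal sketch (not part of the commit) of validating a node config against the models above; the `title` field is assumed to come from BaseNodeData, and all values are illustrative:

    from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData

    node_data = KnowledgeIndexNodeData(
        title="index chunks",  # assumed BaseNodeData field
        dataset_id="dataset-uuid",
        index_chunk_variable_selector=["previous_node", "chunks"],
        chunk_structure="general",
        index_method={
            "indexing_technique": "high_quality",
            "embedding_setting": {
                "embedding_provider_name": "openai",
                "embedding_model_name": "text-embedding-3-small",
            },
            "economy_setting": {"keyword_number": 10},
        },
        retrieval_setting={"search_method": "semantic_search", "top_k": 2},
    )

Pydantic coerces the nested dicts into IndexMethod and RetrievalSetting, and the Literal types reject unknown search methods, indexing techniques, and chunk structures at construction time.
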
api/core/workflow/nodes/knowledge_index/exc.py (new file, +22)
@@ -0,0 +1,22 @@
+class KnowledgeIndexNodeError(ValueError):
+    """Base class for KnowledgeIndexNode errors."""
+
+
+class ModelNotExistError(KnowledgeIndexNodeError):
+    """Raised when the model does not exist."""
+
+
+class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
+    """Raised when the model credentials are not initialized."""
+
+
+class ModelNotSupportedError(KnowledgeIndexNodeError):
+    """Raised when the model is not supported."""
+
+
+class ModelQuotaExceededError(KnowledgeIndexNodeError):
+    """Raised when the model provider quota is exceeded."""
+
+
+class InvalidModelTypeError(KnowledgeIndexNodeError):
+    """Raised when the model is not a Large Language Model."""
api/core/workflow/nodes/knowledge_index/knowledge_index_node.py (new file, +154)
@@ -0,0 +1,154 @@
+import json
+import logging
+import re
+import time
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, cast
+
+from sqlalchemy import Integer, and_, func, or_, text
+from sqlalchemy import cast as sqlalchemy_cast
+
+from core.app.app_config.entities import DatasetRetrieveConfigEntity
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.agent_entities import PlanningStrategy
+from core.entities.model_entities import ModelStatus
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.message_entities import PromptMessageRole
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.simple_prompt_transform import ModelMode
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.entities.metadata_entities import Condition, MetadataCondition
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.variables import StringSegment
+from core.variables.segments import ObjectSegment
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.knowledge_retrieval.template_prompts import (
+    METADATA_FILTER_ASSISTANT_PROMPT_1,
+    METADATA_FILTER_ASSISTANT_PROMPT_2,
+    METADATA_FILTER_COMPLETION_PROMPT,
+    METADATA_FILTER_SYSTEM_PROMPT,
+    METADATA_FILTER_USER_PROMPT_1,
+    METADATA_FILTER_USER_PROMPT_3,
+)
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from libs.json_in_md_parser import parse_and_check_json_markdown
+from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
+from models.workflow import WorkflowNodeExecutionStatus
+from services.dataset_service import DatasetService
+from services.feature_service import FeatureService
+
+from .entities import KnowledgeIndexNodeData, KnowledgeRetrievalNodeData, ModelConfig
+from .exc import (
+    InvalidModelTypeError,
+    KnowledgeIndexNodeError,
+    KnowledgeRetrievalNodeError,
+    ModelCredentialsNotInitializedError,
+    ModelNotExistError,
+    ModelNotSupportedError,
+    ModelQuotaExceededError,
+)
+
+logger = logging.getLogger(__name__)
+
+default_retrieval_model = {
+    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+    "reranking_enable": False,
+    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
+    "top_k": 2,
+    "score_threshold_enabled": False,
+}
+
+
+class KnowledgeIndexNode(LLMNode):
+    _node_data_cls = KnowledgeIndexNodeData  # type: ignore
+    _node_type = NodeType.KNOWLEDGE_INDEX
+
+    def _run(self) -> NodeRunResult:  # type: ignore
+        node_data = cast(KnowledgeIndexNodeData, self.node_data)
+        # extract variables
+        variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
+        if not isinstance(variable, ObjectSegment):
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs={},
+                error="Query variable is not object type.",
+            )
+        chunks = variable.value
+        variables = {"chunks": chunks}
+        if not chunks:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
+            )
+        # check rate limit
+        if self.tenant_id:
+            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
+            if knowledge_rate_limit.enabled:
+                current_time = int(time.time() * 1000)
+                key = f"rate_limit_{self.tenant_id}"
+                redis_client.zadd(key, {current_time: current_time})
+                redis_client.zremrangebyscore(key, 0, current_time - 60000)
+                request_count = redis_client.zcard(key)
+                if request_count > knowledge_rate_limit.limit:
+                    # add ratelimit record
+                    rate_limit_log = RateLimitLog(
+                        tenant_id=self.tenant_id,
+                        subscription_plan=knowledge_rate_limit.subscription_plan,
+                        operation="knowledge",
+                    )
+                    db.session.add(rate_limit_log)
+                    db.session.commit()
+                    return NodeRunResult(
+                        status=WorkflowNodeExecutionStatus.FAILED,
+                        inputs=variables,
+                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
+                        error_type="RateLimitExceeded",
+                    )
+
+        # retrieve knowledge
+        try:
+            results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
+            outputs = {"result": results}
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
+            )
+
+        except KnowledgeIndexNodeError as e:
+            logger.warning("Error when running knowledge index node")
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=variables,
+                error=str(e),
+                error_type=type(e).__name__,
+            )
+        # Temporary handle all exceptions from DatasetRetrieval class here.
+        except Exception as e:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                inputs=variables,
+                error=str(e),
+                error_type=type(e).__name__,
+            )
+
+    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[any]) -> Any:
+        dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
+        if not dataset:
+            raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
+
+        DatasetService.invoke_knowledge_index(
+            dataset=dataset,
+            chunks=chunks,
+            index_method=node_data.index_method,
+            retrieval_setting=node_data.retrieval_setting,
+        )
+
+        pass
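
The rate-limit block in _run above is a sliding-window counter on a Redis sorted set: record the current request as a member scored by its timestamp, trim members older than 60 seconds, then count what remains. Isolated for clarity, using the same redis-py calls as the node:

    import time


    def over_limit(redis_client, tenant_id: str, limit: int) -> bool:
        now_ms = int(time.time() * 1000)
        key = f"rate_limit_{tenant_id}"
        redis_client.zadd(key, {now_ms: now_ms})               # record this request
        redis_client.zremrangebyscore(key, 0, now_ms - 60000)  # drop entries older than 60s
        return redis_client.zcard(key) > limit                 # count requests in the window
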
api/core/workflow/nodes/knowledge_index/template_prompts.py (new file, +66)
@@ -0,0 +1,66 @@
+METADATA_FILTER_SYSTEM_PROMPT = """
+### Job Description',
+You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+### Task
+Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+### Format
+The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+### Constraint
+DO NOT include anything other than the JSON array in your response.
+"""  # noqa: E501
+
+METADATA_FILTER_USER_PROMPT_1 = """
+{ "input_text": "I want to know which company’s email address test@example.com is?",
+"metadata_fields": ["filename", "email", "phone", "address"]
+}
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_1 = """
+```json
+{"metadata_map": [
+    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
+]
+}
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_2 = """
+{"input_text": "What are the movies with a score of more than 9 in 2024?",
+"metadata_fields": ["name", "year", "rating", "country"]}
+"""
+
+METADATA_FILTER_ASSISTANT_PROMPT_2 = """
+```json
+{"metadata_map": [
+    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
+    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
+]}
+```
+"""
+
+METADATA_FILTER_USER_PROMPT_3 = """
+'{{"input_text": "{input_text}",',
+'"metadata_fields": {metadata_fields}}}'
+"""
+
+METADATA_FILTER_COMPLETION_PROMPT = """
+### Job Description
+You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
+### Task
+# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+### Format
+The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
+### Constraint
+DO NOT include anything other than the JSON array in your response.
+### Example
+Here is the chat example between human and assistant, inside <example></example> XML tags.
+<example>
+User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
+User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
+Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
+</example>
+### User Input
+{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
+### Assistant Output
+"""  # noqa: E501
@@ -59,7 +59,6 @@ class MultipleRetrievalConfig(BaseModel):
 class ModelConfig(BaseModel):
     """
     Model Config.
     """
-
     provider: str
     name: str
@@ -59,6 +59,7 @@ class Dataset(db.Model):  # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
+    keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
     collection_binding_id = db.Column(StringUUID, nullable=True)
     retrieval_model = db.Column(JSONB, nullable=True)
     built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
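
The new keyword_number column implies a schema migration that is not shown in this excerpt; under Alembic (which Dify uses for migrations), it would look roughly like this hypothetical sketch:

    import sqlalchemy as sa
    from alembic import op


    def upgrade():
        # hypothetical migration matching the model change above
        op.add_column(
            "datasets",
            sa.Column("keyword_number", sa.Integer(), nullable=True, server_default=sa.text("10")),
        )
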
@@ -21,6 +21,7 @@ from core.plugin.entities.plugin import ModelProviderID
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
 from extensions.ext_database import db
|
|||||||
return documents, batch
|
return documents, batch
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
def save_document_with_dataset_id(
|
||||||
|
dataset: Dataset,
|
||||||
|
knowledge_config: KnowledgeConfig,
|
||||||
|
account: Account | Any,
|
||||||
|
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||||
|
created_from: str = "web",
|
||||||
|
):
|
||||||
|
# check document limit
|
||||||
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
|
if features.billing.enabled:
|
||||||
|
if not knowledge_config.original_document_id:
|
||||||
|
count = 0
|
||||||
|
if knowledge_config.data_source:
|
||||||
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
||||||
|
count = len(upload_file_list)
|
||||||
|
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||||
|
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
||||||
|
for notion_info in notion_info_list: # type: ignore
|
||||||
|
count = count + len(notion_info.pages)
|
||||||
|
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||||
|
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||||
|
count = len(website_info.urls) # type: ignore
|
||||||
|
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||||
|
|
||||||
|
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||||
|
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||||
|
if count > batch_upload_limit:
|
||||||
|
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||||
|
|
||||||
|
DocumentService.check_documents_upload_quota(count, features)
|
||||||
|
|
||||||
|
# if dataset is empty, update dataset data_source_type
|
||||||
|
if not dataset.data_source_type:
|
||||||
|
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
|
||||||
|
|
||||||
|
if not dataset.indexing_technique:
|
||||||
|
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||||
|
raise ValueError("Indexing technique is invalid")
|
||||||
|
|
||||||
|
dataset.indexing_technique = knowledge_config.indexing_technique
|
||||||
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
|
model_manager = ModelManager()
|
||||||
|
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
|
||||||
|
dataset_embedding_model = knowledge_config.embedding_model
|
||||||
|
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
|
||||||
|
else:
|
||||||
|
embedding_model = model_manager.get_default_model_instance(
|
||||||
|
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||||
|
)
|
||||||
|
dataset_embedding_model = embedding_model.model
|
||||||
|
dataset_embedding_model_provider = embedding_model.provider
|
||||||
|
dataset.embedding_model = dataset_embedding_model
|
||||||
|
dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||||
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
|
dataset_embedding_model_provider, dataset_embedding_model
|
||||||
|
)
|
||||||
|
dataset.collection_binding_id = dataset_collection_binding.id
|
||||||
|
if not dataset.retrieval_model:
|
||||||
|
default_retrieval_model = {
|
||||||
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
|
"reranking_enable": False,
|
||||||
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
|
"top_k": 2,
|
||||||
|
"score_threshold_enabled": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset.retrieval_model = (
|
||||||
|
knowledge_config.retrieval_model.model_dump()
|
||||||
|
if knowledge_config.retrieval_model
|
||||||
|
else default_retrieval_model
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
if knowledge_config.original_document_id:
|
||||||
|
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
|
||||||
|
documents.append(document)
|
||||||
|
batch = document.batch
|
||||||
|
else:
|
||||||
|
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||||
|
# save process rule
|
||||||
|
if not dataset_process_rule:
|
||||||
|
process_rule = knowledge_config.process_rule
|
||||||
|
if process_rule:
|
||||||
|
if process_rule.mode in ("custom", "hierarchical"):
|
||||||
|
dataset_process_rule = DatasetProcessRule(
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
mode=process_rule.mode,
|
||||||
|
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
elif process_rule.mode == "automatic":
|
||||||
|
dataset_process_rule = DatasetProcessRule(
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
mode=process_rule.mode,
|
||||||
|
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
|
||||||
|
created_by=account.id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.warn(
|
||||||
|
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
db.session.add(dataset_process_rule)
|
||||||
|
db.session.commit()
|
||||||
|
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
|
||||||
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
|
position = DocumentService.get_documents_position(dataset.id)
|
||||||
|
document_ids = []
|
||||||
|
duplicate_document_ids = []
|
||||||
|
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    for file_id in upload_file_list:
                        file = (
                            db.session.query(UploadFile)
                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                            .first()
                        )

                        # raise error if file not found
                        if not file:
                            raise FileNotExistsError()

                        file_name = file.name
                        data_source_info = {
                            "upload_file_id": file_id,
                        }
                        # check duplicate
                        if knowledge_config.duplicate:
                            document = Document.query.filter_by(
                                dataset_id=dataset.id,
                                tenant_id=current_user.current_tenant_id,
                                data_source_type="upload_file",
                                enabled=True,
                                name=file_name,
                            ).first()
                            if document:
                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                                document.created_from = created_from
                                document.doc_form = knowledge_config.doc_form
                                document.doc_language = knowledge_config.doc_language
                                document.data_source_info = json.dumps(data_source_info)
                                document.batch = batch
                                document.indexing_status = "waiting"
                                db.session.add(document)
                                documents.append(document)
                                duplicate_document_ids.append(document.id)
                                continue
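                        # No duplicate found (or duplication is allowed): create a
                        # fresh Document row for this uploaded file.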
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                    if not notion_info_list:
                        raise ValueError("No notion info list found.")
                    exist_page_ids = []
                    exist_document = {}
                    documents = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="notion_import",
                        enabled=True,
                    ).all()
                    if documents:
                        for document in documents:
                            data_source_info = json.loads(document.data_source_info)
                            exist_page_ids.append(data_source_info["notion_page_id"])
                            exist_document[data_source_info["notion_page_id"]] = document.id
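                    # Import only pages that are new; pages dropped from the
                    # selection are cleaned up asynchronously below.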
                    for notion_info in notion_info_list:
                        workspace_id = notion_info.workspace_id
                        data_source_binding = DataSourceOauthBinding.query.filter(
                            db.and_(
                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                DataSourceOauthBinding.provider == "notion",
                                DataSourceOauthBinding.disabled == False,
                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                            )
                        ).first()
                        if not data_source_binding:
                            raise ValueError("Data source binding not found.")
                        for page in notion_info.pages:
                            if page.page_id not in exist_page_ids:
                                data_source_info = {
                                    "notion_workspace_id": workspace_id,
                                    "notion_page_id": page.page_id,
                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                    "type": page.type,
                                }
                                # Truncate page name to 255 characters to prevent DB field length errors
                                truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                                document = DocumentService.build_document(
                                    dataset,
                                    dataset_process_rule.id,  # type: ignore
                                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                    knowledge_config.doc_form,
                                    knowledge_config.doc_language,
                                    data_source_info,
                                    created_from,
                                    position,
                                    account,
                                    truncated_page_name,
                                    batch,
                                )
                                db.session.add(document)
                                db.session.flush()
                                document_ids.append(document.id)
                                documents.append(document)
                                position += 1
                            else:
                                exist_document.pop(page.page_id)
                    # delete not selected documents
                    if len(exist_document) > 0:
                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                    website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                    if not website_info:
                        raise ValueError("No website info list found.")
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,
                            "mode": "crawl",
                        }
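                        # Keep the document name within the 255-char name column;
                        # overly long URLs are truncated with an ellipsis.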
                        if len(url) > 255:
                            document_name = url[:200] + "..."
                        else:
                            document_name = url
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            document_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
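                # A single commit covers every document created under the lock.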
                db.session.commit()

            # trigger async task
            if document_ids:
                document_indexing_task.delay(dataset.id, document_ids)
            if duplicate_document_ids:
                duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch

    @staticmethod
    def invoke_knowledge_index(
        dataset: Dataset,
        chunks: list[Any],
        index_method: IndexMethod,
        retrieval_setting: RetrievalSetting,
        account: Account | Any,
        original_document_id: str | None = None,
        created_from: str = "rag-pipeline",
    ):
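        """
        Create and index documents for a rag-pipeline run.

        Follows the same flow as the upload-file import above: configure the
        dataset's indexing technique, embedding model and retrieval model, then
        build one Document per chunk under a per-dataset lock and dispatch the
        async indexing tasks.
        """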
        if not dataset.indexing_technique:
            if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = index_method.indexing_technique
            if index_method.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if (
                    index_method.embedding_setting.embedding_model
                    and index_method.embedding_setting.embedding_model_provider
                ):
                    dataset_embedding_model = index_method.embedding_setting.embedding_model
                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
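        # Fall back to a plain semantic-search configuration when no retrieval
        # setting is supplied.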
        if not dataset.retrieval_model:
            default_retrieval_model = {
                "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                "reranking_enable": False,
                "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                "top_k": 2,
                "score_threshold_enabled": False,
            }

            dataset.retrieval_model = (
                retrieval_setting.model_dump() if retrieval_setting else default_retrieval_model
            )  # type: ignore

        documents = []
        if original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
|
||||||
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
|
position = DocumentService.get_documents_position(dataset.id)
|
||||||
|
document_ids = []
|
||||||
|
duplicate_document_ids = []
|
||||||
|
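            # Build one Document per chunk; each chunk is expected to reference
            # an uploaded file (see the assumption below).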
            for chunk in chunks:
                # Assumption: each chunk carries the id of its uploaded file;
                # that id is used to look the file up within this tenant.
                file_id = chunk
                file = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                    .first()
                )

                # raise error if file not found
                if not file:
                    raise FileNotExistsError()

                file_name = file.name
                data_source_info = {
                    "upload_file_id": file_id,
                }
                # check duplicate
                if knowledge_config.duplicate:
                    document = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="upload_file",
                        enabled=True,
                        name=file_name,
                    ).first()
                    if document:
                        document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                        document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                        document.created_from = created_from
                        document.doc_form = knowledge_config.doc_form
                        document.doc_language = knowledge_config.doc_language
                        document.data_source_info = json.dumps(data_source_info)
                        document.batch = batch
                        document.indexing_status = "waiting"
                        db.session.add(document)
                        documents.append(document)
                        duplicate_document_ids.append(document.id)
                        continue
                document = DocumentService.build_document(
                    dataset,
                    dataset_process_rule.id,  # type: ignore
                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                    knowledge_config.doc_form,
                    knowledge_config.doc_language,
                    data_source_info,
                    created_from,
                    position,
                    account,
                    file_name,
                    batch,
                )
                db.session.add(document)
                db.session.flush()
                document_ids.append(document.id)
                documents.append(document)
                position += 1

            db.session.commit()

        # trigger async task
        if document_ids:
            document_indexing_task.delay(dataset.id, document_ids)
        if duplicate_document_ids:
            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch

    @staticmethod
    def check_documents_upload_quota(count: int, features: FeatureModel):
        can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
        if count > can_upload_size:
            raise ValueError(
                f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded."
            )