diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index f392d6a2dd..4b5f4716ed 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -213,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse): created_by: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -298,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parent_parallel_id: Optional[str] = None diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 49818b69d2..bbd9531b19 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -150,7 +150,7 @@ class AdvancedPromptTransform(PromptTransform): for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) - raw_prompt.replace("{{#context#}}", context or "") + raw_prompt = raw_prompt.replace("{{#context#}}", context or "") prompt = vp.convert_template(raw_prompt).text else: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 21f05f2a18..d61d2d1de0 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -359,7 +359,7 @@ class LLMNode(BaseNode[LLMNodeData]): return [] raise ValueError(f"Invalid variable type: {type(variable)}") - def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]: + def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: return diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index bb7278b70e..82ba34a7fc 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -2,6 +2,9 @@ from collections.abc import Mapping, Sequence from os import path from typing import Any +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file.models import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -14,6 +17,8 @@ from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from enums import NodeType +from extensions.ext_database import db +from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus @@ -167,45 +172,59 @@ class ToolNode(BaseNode[ToolNodeData]): result = [] for response in tool_response: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: - url = response.message - ext = path.splitext(url)[1] - mimetype = response.meta.get("mime_type", "image/jpeg") - tool_file_id = response.save_as or url.split("/")[-1] + url = str(response.message) if response.message else None + ext = path.splitext(url)[1] if url else ".bin" + tool_file_id = response.save_as or str(url).split("/")[-1].split(".")[0] transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - # get tool file id - tool_file_id = str(url).split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + result.append( File( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, remote_url=url, - related_id=tool_file_id, - filename=tool_file_id, + related_id=tool_file.id, + filename=tool_file.name, extension=ext, - mime_type=mimetype, + mime_type=tool_file.mimetype, + size=tool_file.size, ) ) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id tool_file_id = str(response.message).split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") result.append( File( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, + related_id=tool_file.id, + filename=tool_file.name, extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get("mime_type", "application/octet-stream"), + mime_type=tool_file.mimetype, + size=tool_file.size, ) ) elif response.type == ToolInvokeMessage.MessageType.LINK: url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE - mimetype = response.meta.get("mime_type", "application/octet-stream") tool_file_id = url.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") if "." in url: extension = "." + url.split("/")[-1].split(".")[1] else: @@ -215,10 +234,11 @@ class ToolNode(BaseNode[ToolNodeData]): type=FileType(response.save_as), transfer_method=transfer_method, remote_url=url, - filename=tool_file_id, - related_id=tool_file_id, + filename=tool_file.name, + related_id=tool_file.id, extension=extension, - mime_type=mimetype, + mime_type=tool_file.mimetype, + size=tool_file.size, ) result.append(file)