mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 11:28:57 +08:00
chore(api/core): Improve FileVar's type hint and imports. (#7290)
This commit is contained in:
parent
6ff7fd80a1
commit
8f16165f92
@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
|||||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
|||||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||||
from models.model import App, AppMode, Message, MessageAnnotation
|
from models.model import App, AppMode, Message, MessageAnnotation
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
class AppRunner:
|
class AppRunner:
|
||||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
query: Optional[str] = None) -> int:
|
query: Optional[str] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Get pre calculate rest tokens
|
Get pre calculate rest tokens
|
||||||
@ -126,7 +128,7 @@ class AppRunner:
|
|||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
memory: Optional[TokenBufferMemory] = None) \
|
memory: Optional[TokenBufferMemory] = None) \
|
||||||
@ -366,7 +368,7 @@ class AppRunner:
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
trace_manager=app_generate_entity.trace_manager
|
trace_manager=app_generate_entity.trace_manager
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||||
queue_manager: AppQueueManager,
|
queue_manager: AppQueueManager,
|
||||||
prompt_messages: list[PromptMessage]) -> bool:
|
prompt_messages: list[PromptMessage]) -> bool:
|
||||||
@ -418,7 +420,7 @@ class AppRunner:
|
|||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query
|
query=query
|
||||||
)
|
)
|
||||||
|
|
||||||
def query_app_annotations_to_reply(self, app_record: App,
|
def query_app_annotations_to_reply(self, app_record: App,
|
||||||
message: Message,
|
message: Message,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||||||
node_id: str
|
node_id: str
|
||||||
inputs: dict
|
inputs: dict
|
||||||
|
|
||||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||||
|
@ -99,7 +99,7 @@ class MessageFileParser:
|
|||||||
# return all file objs
|
# return all file objs
|
||||||
return new_files
|
return new_files
|
||||||
|
|
||||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
|
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
||||||
"""
|
"""
|
||||||
transform message files
|
transform message files
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ class MessageFileParser:
|
|||||||
|
|
||||||
return type_file_objs
|
return type_file_objs
|
||||||
|
|
||||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
|
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
||||||
"""
|
"""
|
||||||
transform file to file obj
|
transform file to file obj
|
||||||
|
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import PromptTemplateEntity
|
from core.app.app_config.entities import PromptTemplateEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
|
|||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
class ModelMode(enum.Enum):
|
class ModelMode(enum.Enum):
|
||||||
COMPLETION = 'completion'
|
COMPLETION = 'completion'
|
||||||
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||||
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) \
|
model_config: ModelConfigWithCredentialsEntity) \
|
||||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) \
|
model_config: ModelConfigWithCredentialsEntity) \
|
||||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
|
|||||||
|
|
||||||
return [self.get_last_user_message(prompt, files)], stops
|
return [self.get_last_user_message(prompt, files)], stops
|
||||||
|
|
||||||
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
|
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||||
for file in files:
|
for file in files:
|
||||||
|
@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
from pydantic_core.core_schema import ValidationInfo
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolDescription,
|
ToolDescription,
|
||||||
ToolIdentity,
|
ToolIdentity,
|
||||||
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
|
|||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel, ABC):
|
class Tool(BaseModel, ABC):
|
||||||
identity: Optional[ToolIdentity] = None
|
identity: Optional[ToolIdentity] = None
|
||||||
@ -76,7 +78,7 @@ class Tool(BaseModel, ABC):
|
|||||||
description=self.description.model_copy() if self.description else None,
|
description=self.description.model_copy() if self.description else None,
|
||||||
runtime=Tool.Runtime(**runtime),
|
runtime=Tool.Runtime(**runtime),
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def tool_provider_type(self) -> ToolProviderType:
|
def tool_provider_type(self) -> ToolProviderType:
|
||||||
"""
|
"""
|
||||||
@ -84,7 +86,7 @@ class Tool(BaseModel, ABC):
|
|||||||
|
|
||||||
:return: the tool provider type
|
:return: the tool provider type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||||
"""
|
"""
|
||||||
load variables from database
|
load variables from database
|
||||||
@ -99,7 +101,7 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||||
|
|
||||||
def set_text_variable(self, variable_name: str, text: str) -> None:
|
def set_text_variable(self, variable_name: str, text: str) -> None:
|
||||||
@ -108,9 +110,9 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.variables.set_text(self.identity.name, variable_name, text)
|
self.variables.set_text(self.identity.name, variable_name, text)
|
||||||
|
|
||||||
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
|
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
|
||||||
"""
|
"""
|
||||||
get a variable
|
get a variable
|
||||||
@ -120,14 +122,14 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(name, Enum):
|
if isinstance(name, Enum):
|
||||||
name = name.value
|
name = name.value
|
||||||
|
|
||||||
for variable in self.variables.pool:
|
for variable in self.variables.pool:
|
||||||
if variable.name == name:
|
if variable.name == name:
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
|
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
|
||||||
@ -138,9 +140,9 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.get_variable(self.VARIABLE_KEY.IMAGE)
|
return self.get_variable(self.VARIABLE_KEY.IMAGE)
|
||||||
|
|
||||||
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
||||||
"""
|
"""
|
||||||
get a variable file
|
get a variable file
|
||||||
@ -151,7 +153,7 @@ class Tool(BaseModel, ABC):
|
|||||||
variable = self.get_variable(name)
|
variable = self.get_variable(name)
|
||||||
if not variable:
|
if not variable:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not isinstance(variable, ToolRuntimeImageVariable):
|
if not isinstance(variable, ToolRuntimeImageVariable):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -160,9 +162,9 @@ class Tool(BaseModel, ABC):
|
|||||||
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
|
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
|
||||||
if not file_binary:
|
if not file_binary:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return file_binary[0]
|
return file_binary[0]
|
||||||
|
|
||||||
def list_variables(self) -> list[ToolRuntimeVariable]:
|
def list_variables(self) -> list[ToolRuntimeVariable]:
|
||||||
"""
|
"""
|
||||||
list all variables
|
list all variables
|
||||||
@ -171,9 +173,9 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return self.variables.pool
|
return self.variables.pool
|
||||||
|
|
||||||
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
|
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
|
||||||
"""
|
"""
|
||||||
list all image variables
|
list all image variables
|
||||||
@ -182,9 +184,9 @@ class Tool(BaseModel, ABC):
|
|||||||
"""
|
"""
|
||||||
if not self.variables:
|
if not self.variables:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for variable in self.variables.pool:
|
for variable in self.variables.pool:
|
||||||
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
|
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
|
||||||
result.append(variable)
|
result.append(variable)
|
||||||
@ -225,7 +227,7 @@ class Tool(BaseModel, ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
validate the credentials
|
validate the credentials
|
||||||
@ -244,7 +246,7 @@ class Tool(BaseModel, ABC):
|
|||||||
:return: the runtime parameters
|
:return: the runtime parameters
|
||||||
"""
|
"""
|
||||||
return self.parameters or []
|
return self.parameters or []
|
||||||
|
|
||||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||||
"""
|
"""
|
||||||
get all runtime parameters
|
get all runtime parameters
|
||||||
@ -278,7 +280,7 @@ class Tool(BaseModel, ABC):
|
|||||||
parameters.append(parameter)
|
parameters.append(parameter)
|
||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
create an image message
|
create an image message
|
||||||
@ -286,18 +288,18 @@ class Tool(BaseModel, ABC):
|
|||||||
:param image: the url of the image
|
:param image: the url of the image
|
||||||
:return: the image message
|
:return: the image message
|
||||||
"""
|
"""
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||||
message=image,
|
message=image,
|
||||||
save_as=save_as)
|
save_as=save_as)
|
||||||
|
|
||||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||||
message='',
|
message='',
|
||||||
meta={
|
meta={
|
||||||
'file_var': file_var
|
'file_var': file_var
|
||||||
},
|
},
|
||||||
save_as='')
|
save_as='')
|
||||||
|
|
||||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
create a link message
|
create a link message
|
||||||
@ -305,10 +307,10 @@ class Tool(BaseModel, ABC):
|
|||||||
:param link: the url of the link
|
:param link: the url of the link
|
||||||
:return: the link message
|
:return: the link message
|
||||||
"""
|
"""
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||||
message=link,
|
message=link,
|
||||||
save_as=save_as)
|
save_as=save_as)
|
||||||
|
|
||||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
create a text message
|
create a text message
|
||||||
@ -321,7 +323,7 @@ class Tool(BaseModel, ABC):
|
|||||||
message=text,
|
message=text,
|
||||||
save_as=save_as
|
save_as=save_as
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
create a blob message
|
create a blob message
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
from core.file.file_obj import FileTransferMethod, FileType
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
@ -27,12 +27,12 @@ class ToolFileMessageTransformer:
|
|||||||
# try to download image
|
# try to download image
|
||||||
try:
|
try:
|
||||||
file = ToolFileManager.create_file_by_url(
|
file = ToolFileManager.create_file_by_url(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
file_url=message.message
|
file_url=message.message
|
||||||
)
|
)
|
||||||
|
|
||||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||||
|
|
||||||
result.append(ToolInvokeMessage(
|
result.append(ToolInvokeMessage(
|
||||||
@ -55,14 +55,14 @@ class ToolFileMessageTransformer:
|
|||||||
# if message is str, encode it to bytes
|
# if message is str, encode it to bytes
|
||||||
if isinstance(message.message, str):
|
if isinstance(message.message, str):
|
||||||
message.message = message.message.encode('utf-8')
|
message.message = message.message.encode('utf-8')
|
||||||
|
|
||||||
file = ToolFileManager.create_file_by_raw(
|
file = ToolFileManager.create_file_by_raw(
|
||||||
user_id=user_id, tenant_id=tenant_id,
|
user_id=user_id, tenant_id=tenant_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
file_binary=message.message,
|
file_binary=message.message,
|
||||||
mimetype=mimetype
|
mimetype=mimetype
|
||||||
)
|
)
|
||||||
|
|
||||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||||
|
|
||||||
# check if file is image
|
# check if file is image
|
||||||
@ -81,7 +81,7 @@ class ToolFileMessageTransformer:
|
|||||||
meta=message.meta.copy() if message.meta is not None else {},
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
))
|
))
|
||||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||||
file_var: FileVar = message.meta.get('file_var')
|
file_var = message.meta.get('file_var')
|
||||||
if file_var:
|
if file_var:
|
||||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||||
@ -103,7 +103,7 @@ class ToolFileMessageTransformer:
|
|||||||
result.append(message)
|
result.append(message)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Optional, cast
|
from typing import TYPE_CHECKING, Optional, cast
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.entities.provider_entities import QuotaUnit
|
from core.entities.provider_entities import QuotaUnit
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
@ -39,6 +38,10 @@ from models.model import Conversation
|
|||||||
from models.provider import Provider, ProviderType
|
from models.provider import Provider, ProviderType
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(BaseNode):
|
class LLMNode(BaseNode):
|
||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
@ -71,7 +74,7 @@ class LLMNode(BaseNode):
|
|||||||
node_inputs = {}
|
node_inputs = {}
|
||||||
|
|
||||||
# fetch files
|
# fetch files
|
||||||
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
|
files = self._fetch_files(node_data, variable_pool)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||||
@ -322,7 +325,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
|
||||||
"""
|
"""
|
||||||
Fetch files
|
Fetch files
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
@ -521,7 +524,7 @@ class LLMNode(BaseNode):
|
|||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
query_prompt_template: Optional[str],
|
query_prompt_template: Optional[str],
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list[FileVar],
|
files: list["FileVar"],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity) \
|
model_config: ModelConfigWithCredentialsEntity) \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user