chore(api/core): Improve FileVar's type hint and imports. (#7290)

This commit is contained in:
-LAN- 2024-08-15 12:43:18 +08:00 committed by GitHub
parent 6ff7fd80a1
commit 8f16165f92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 68 additions and 59 deletions

View File

@ -1,6 +1,6 @@
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App, def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileVar], files: list["FileVar"],
query: Optional[str] = None) -> int: query: Optional[str] = None) -> int:
""" """
Get pre calculate rest tokens Get pre calculate rest tokens
@ -126,7 +128,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileVar], files: list["FileVar"],
query: Optional[str] = None, query: Optional[str] = None,
context: Optional[str] = None, context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \ memory: Optional[TokenBufferMemory] = None) \

View File

@ -99,7 +99,7 @@ class MessageFileParser:
# return all file objs # return all file objs
return new_files return new_files
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
""" """
transform message files transform message files
@ -144,7 +144,7 @@ class MessageFileParser:
return type_file_objs return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
""" """
transform file to file obj transform file to file obj

View File

@ -1,11 +1,10 @@
import enum import enum
import json import json
import os import os
from typing import Optional from typing import TYPE_CHECKING, Optional
from core.app.app_config.entities import PromptTemplateEntity from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
PromptMessage, PromptMessage,
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode from models.model import AppMode
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class ModelMode(enum.Enum): class ModelMode(enum.Enum):
COMPLETION = 'completion' COMPLETION = 'completion'
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
prompt_template_entity: PromptTemplateEntity, prompt_template_entity: PromptTemplateEntity,
inputs: dict, inputs: dict,
query: str, query: str,
files: list[FileVar], files: list["FileVar"],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \ model_config: ModelConfigWithCredentialsEntity) -> \
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
files: list[FileVar], files: list["FileVar"],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
inputs: dict, inputs: dict,
query: str, query: str,
context: Optional[str], context: Optional[str],
files: list[FileVar], files: list["FileVar"],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \ model_config: ModelConfigWithCredentialsEntity) \
-> tuple[list[PromptMessage], Optional[list[str]]]: -> tuple[list[PromptMessage], Optional[list[str]]]:
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
return [self.get_last_user_message(prompt, files)], stops return [self.get_last_user_message(prompt, files)], stops
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
if files: if files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)] prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files: for file in files:

View File

@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolDescription, ToolDescription,
ToolIdentity, ToolIdentity,
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class Tool(BaseModel, ABC): class Tool(BaseModel, ABC):
identity: Optional[ToolIdentity] = None identity: Optional[ToolIdentity] = None
@ -290,7 +292,7 @@ class Tool(BaseModel, ABC):
message=image, message=image,
save_as=save_as) save_as=save_as)
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
message='', message='',
meta={ meta={

View File

@ -1,7 +1,7 @@
import logging import logging
from mimetypes import guess_extension from mimetypes import guess_extension
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
@ -81,7 +81,7 @@ class ToolFileMessageTransformer:
meta=message.meta.copy() if message.meta is not None else {}, meta=message.meta.copy() if message.meta is not None else {},
)) ))
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var: FileVar = message.meta.get('file_var') file_var = message.meta.get('file_var')
if file_var: if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE: if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension) url = cls.get_tool_file_url(file_var.related_id, file_var.extension)

View File

@ -1,14 +1,13 @@
import json import json
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from typing import Optional, cast from typing import TYPE_CHECKING, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
@ -39,6 +38,10 @@ from models.model import Conversation
from models.provider import Provider, ProviderType from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class LLMNode(BaseNode): class LLMNode(BaseNode):
_node_data_cls = LLMNodeData _node_data_cls = LLMNodeData
@ -71,7 +74,7 @@ class LLMNode(BaseNode):
node_inputs = {} node_inputs = {}
# fetch files # fetch files
files: list[FileVar] = self._fetch_files(node_data, variable_pool) files = self._fetch_files(node_data, variable_pool)
if files: if files:
node_inputs['#files#'] = [file.to_dict() for file in files] node_inputs['#files#'] = [file.to_dict() for file in files]
@ -322,7 +325,7 @@ class LLMNode(BaseNode):
return inputs return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
""" """
Fetch files Fetch files
:param node_data: node data :param node_data: node data
@ -521,7 +524,7 @@ class LLMNode(BaseNode):
query: Optional[str], query: Optional[str],
query_prompt_template: Optional[str], query_prompt_template: Optional[str],
inputs: dict[str, str], inputs: dict[str, str],
files: list[FileVar], files: list["FileVar"],
context: Optional[str], context: Optional[str],
memory: Optional[TokenBufferMemory], memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) \ model_config: ModelConfigWithCredentialsEntity) \