diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py
index 4b364d15c6..d8d7b3e860 100644
--- a/api/core/plugin/manager/base.py
+++ b/api/core/plugin/manager/base.py
@@ -82,7 +82,7 @@ class BasePluginManager:
         Make a stream request to the plugin daemon inner API
         """
         response = self._request(method, path, headers, data, params, files, stream=True)
-        for line in response.iter_lines():
+        for line in response.iter_lines(chunk_size=1024 * 8):
             line = line.decode("utf-8").strip()
             if line.startswith("data:"):
                 line = line[5:].strip()
diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/manager/tool.py
index 4c3abd3acf..7592f867e1 100644
--- a/api/core/plugin/manager/tool.py
+++ b/api/core/plugin/manager/tool.py
@@ -110,7 +110,69 @@ class PluginToolManager(BasePluginManager):
                 "Content-Type": "application/json",
             },
         )
-        return response
+
+        class FileChunk:
+            """
+            Reassembly buffer for one chunked blob; only used for internal processing.
+            """
+
+            bytes_written: int
+            total_length: int
+            data: bytearray
+
+            def __init__(self, total_length: int):
+                self.bytes_written = 0
+                self.total_length = total_length
+                self.data = bytearray(total_length)
+
+        # Limits are enforced here because the chunk metadata is supplied by the plugin
+        max_file_size = 30 * 1024 * 1024
+        max_chunk_size = 8192
+
+        files: dict[str, FileChunk] = {}
+        for resp in response:
+            if resp.type != ToolInvokeMessage.MessageType.BLOB_CHUNK:
+                yield resp
+                continue
+
+            assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
+            # Get blob chunk information
+            chunk_id = resp.message.id
+            blob_data = resp.message.blob
+
+            # Initialize buffer for this file if it doesn't exist
+            if chunk_id not in files:
+                # Reject oversized declarations before allocating the buffer
+                if resp.message.total_length > max_file_size:
+                    raise ValueError("File is too large which reached the limit of 30MB")
+                files[chunk_id] = FileChunk(resp.message.total_length)
+            file_chunk = files[chunk_id]
+
+            # Check if file is too large (30MB limit)
+            if file_chunk.bytes_written + len(blob_data) > max_file_size:
+                # Drop the partial file so it cannot be completed later
+                del files[chunk_id]
+                raise ValueError("File is too large which reached the limit of 30MB")
+
+            # Check if single chunk is too large (8KB limit)
+            if len(blob_data) > max_chunk_size:
+                raise ValueError("File chunk is too large which reached the limit of 8KB")
+
+            # Append the blob data to the buffer (an empty final chunk is a no-op)
+            start = file_chunk.bytes_written
+            file_chunk.data[start : start + len(blob_data)] = blob_data
+            file_chunk.bytes_written = start + len(blob_data)
+
+            # If this is the final chunk, yield a complete blob message
+            if resp.message.end:
+                # Release the buffer and emit only the bytes actually received
+                del files[chunk_id]
+                complete_blob = bytes(file_chunk.data[: file_chunk.bytes_written])
+                yield ToolInvokeMessage(
+                    type=ToolInvokeMessage.MessageType.BLOB,
+                    message=ToolInvokeMessage.BlobMessage(blob=complete_blob),
+                    meta=resp.meta,
+                )
 
     def validate_provider_credentials(
         self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index d756763137..37375f4a71 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -120,6 +120,13 @@ class ToolInvokeMessage(BaseModel):
     class BlobMessage(BaseModel):
         blob: bytes
 
+    class BlobChunkMessage(BaseModel):
+        id: str = Field(..., description="The id of the blob")
+        sequence: int = Field(..., description="The sequence of the chunk")
+        total_length: int = Field(..., description="The total length of the blob")
+        blob: bytes = Field(..., description="The blob data of the chunk")
+        end: bool = Field(..., description="Whether the chunk is the last chunk")
+
     class FileMessage(BaseModel):
         pass
 
@@ -180,12 +187,15 @@ class ToolInvokeMessage(BaseModel):
         VARIABLE = "variable"
         FILE = "file"
         LOG = "log"
+        BLOB_CHUNK = "blob_chunk"
 
     type: MessageType = MessageType.TEXT
     """
     plain text, image url or link url
     """
-    message: JsonMessage | TextMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+    message: (
+        JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+    )
     meta: dict[str, Any] | None = None
 
     @field_validator("message", mode="before")