feat: support LLM understand video (#9828)

This commit is contained in:
非法操作 2024-11-08 13:22:52 +08:00 committed by GitHub
parent c9f785e00f
commit 033ab5490b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 69 additions and 20 deletions

View File

@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration # Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64 MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024

View File

@ -634,12 +634,17 @@ class IndexingConfig(BaseSettings):
) )
class ImageFormatConfig(BaseSettings): class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64", default="base64",
) )
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings): class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field( CELERY_BEAT_SCHEDULER_TIME: int = Field(
@ -742,7 +747,7 @@ class FeatureConfig(
FileAccessConfig, FileAccessConfig,
FileUploadConfig, FileUploadConfig,
HttpConfig, HttpConfig,
ImageFormatConfig, VisionFormatConfig,
InnerAPIConfig, InnerAPIConfig,
IndexingConfig, IndexingConfig,
LoggingConfig, LoggingConfig,

View File

@ -3,7 +3,7 @@ import base64
from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _: case _:
raise ValueError(f"file type {f.type} is not supported") raise ValueError(f"file type {f.type} is not supported")
@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /): def _get_encoded_string(f: File, /):
match f.transfer_method: match f.transfer_method:
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content content = response.content
encoded_string = base64.b64encode(content).decode("utf-8") encoded_string = base64.b64encode(content).decode("utf-8")
@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
return _to_base64_data_string(f) return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO: case FileType.AUDIO:
return _get_encoded_string(f) return _get_encoded_string(f)
case _: case _:

View File

@ -12,11 +12,13 @@ from .message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
VideoPromptMessageContent,
) )
from .model_entities import ModelPropertyKey from .model_entities import ModelPropertyKey
__all__ = [ __all__ = [
"ImagePromptMessageContent", "ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage", "PromptMessage",
"PromptMessageRole", "PromptMessageRole",
"LLMUsage", "LLMUsage",

View File

@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT type: PromptMessageContentType = PromptMessageContentType.TEXT
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent): class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data") data: str = Field(..., description="Base64 encoded audio data")

View File

@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
AIModelEntity, AIModelEntity,
@ -431,6 +432,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {"image": image_url} sub_message_dict = {"image": image_url}
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)
# resort sub_messages to ensure text is always at last # resort sub_messages to ensure text is always at last
sub_messages = sorted(sub_messages, key=lambda x: "text" in x) sub_messages = sorted(sub_messages, key=lambda x: "text" in x)

View File

@ -313,21 +313,35 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
return params return params
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
if isinstance(prompt_message, str): if isinstance(prompt_message, list):
sub_messages = []
for item in prompt_message:
if item.type == PromptMessageContentType.IMAGE:
sub_messages.append(
{
"type": "image_url",
"image_url": {"url": self._remove_base64_header(item.data)},
}
)
elif item.type == PromptMessageContentType.VIDEO:
sub_messages.append(
{
"type": "video_url",
"video_url": {"url": self._remove_base64_header(item.data)},
}
)
else:
sub_messages.append({"type": "text", "text": item.data})
return sub_messages
else:
return [{"type": "text", "text": prompt_message}] return [{"type": "text", "text": prompt_message}]
return [ def _remove_base64_header(self, file_content: str) -> str:
{"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}} if file_content.startswith("data:"):
if item.type == PromptMessageContentType.IMAGE data_split = file_content.split(";base64,")
else {"type": "text", "text": item.data} return data_split[1]
for item in prompt_message
]
def _remove_image_header(self, image: str) -> str: return file_content
if image.startswith("data:image"):
return image.split(",")[1]
return image
def _handle_generate_response( def _handle_generate_response(
self, self,

View File

@ -14,6 +14,7 @@ from core.model_runtime.entities import (
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]):
# cuz vision detail is related to the configuration from FileUpload feature. # cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail content_item.detail = vision_detail
prompt_message_content.append(content_item) prompt_message_content.append(content_item)
elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent): elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
):
prompt_message_content.append(content_item) prompt_message_content.append(content_item)
if len(prompt_message_content) > 1: if len(prompt_message_content) > 1:

View File

@ -468,8 +468,8 @@ const Configuration: FC = () => {
transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
}, },
enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled),
allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video],
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`),
allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3,
fileUploadConfig: fileUploadConfigResponse, fileUploadConfig: fileUploadConfigResponse,