mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 06:19:04 +08:00
feat: enhance gemini models (#11497)
This commit is contained in:
parent
56cfdce453
commit
74fdc16bd1
@ -50,12 +50,12 @@ def to_prompt_message_content(
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=image_detail_config)
|
||||
return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _get_encoded_string(f)
|
||||
data = _to_base64_data_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
|
||||
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case FileType.VIDEO:
|
||||
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
|
||||
data = _to_url(f)
|
||||
@ -65,14 +65,8 @@ def to_prompt_message_content(
|
||||
raise ValueError("Missing file extension")
|
||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case FileType.DOCUMENT:
|
||||
data = _get_encoded_string(f)
|
||||
if f.mime_type is None:
|
||||
raise ValueError("Missing file mime_type")
|
||||
return DocumentPromptMessageContent(
|
||||
encode_format="base64",
|
||||
mime_type=f.mime_type,
|
||||
data=data,
|
||||
)
|
||||
data = _to_base64_data_string(f)
|
||||
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
|
||||
|
@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||
|
||||
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
format: str = Field("jpg", description="Image format")
|
||||
|
||||
|
||||
class DocumentPromptMessageContent(PromptMessageContent):
|
||||
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
||||
encode_format: Literal["base64"]
|
||||
mime_type: str
|
||||
data: str
|
||||
format: str = Field(..., description="Document format")
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
|
@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif isinstance(message_content, DocumentPromptMessageContent):
|
||||
if message_content.mime_type != "application/pdf":
|
||||
data_split = message_content.data.split(";base64,")
|
||||
mime_type = data_split[0].replace("data:", "")
|
||||
base64_data = data_split[1]
|
||||
if mime_type != "application/pdf":
|
||||
raise ValueError(
|
||||
f"Unsupported document type {message_content.mime_type}, "
|
||||
"only support application/pdf"
|
||||
f"Unsupported document type {mime_type}, " "only support application/pdf"
|
||||
)
|
||||
sub_message_dict = {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": message_content.encode_format,
|
||||
"media_type": message_content.mime_type,
|
||||
"data": message_content.data,
|
||||
"media_type": mime_type,
|
||||
"data": base64_data,
|
||||
},
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
|
@ -8,6 +8,8 @@ features:
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
|
@ -7,6 +7,9 @@ features:
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
|
@ -1,29 +1,30 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Optional, Union
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.generativeai as genai
|
||||
import requests
|
||||
from google.api_core import exceptions
|
||||
from google.generativeai.client import _ClientManager
|
||||
from google.generativeai.types import ContentType, GenerateContentResponse
|
||||
from google.generativeai.types import ContentType, File, GenerateContentResponse
|
||||
from google.generativeai.types.content_types import to_part
|
||||
from PIL import Image
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
GOOGLE_AVAILABLE_MIMETYPE = [
|
||||
"application/pdf",
|
||||
"application/x-javascript",
|
||||
"text/javascript",
|
||||
"application/x-python",
|
||||
"text/x-python",
|
||||
"text/plain",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"text/md",
|
||||
"text/csv",
|
||||
"text/xml",
|
||||
"text/rtf",
|
||||
]
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
genai.configure(api_key=credentials["google_api_key"])
|
||||
google_model = genai.GenerativeModel(model_name=model)
|
||||
|
||||
history = []
|
||||
|
||||
# hack for gemini-pro-vision, which currently does not support multi-turn chat
|
||||
if model == "gemini-pro-vision":
|
||||
last_msg = prompt_messages[-1]
|
||||
content = self._format_message_to_glm_content(last_msg)
|
||||
history.append(content)
|
||||
else:
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
# Create a new ClientManager with tenant's API key
|
||||
new_client_manager = _ClientManager()
|
||||
new_client_manager.configure(api_key=credentials["google_api_key"])
|
||||
new_custom_client = new_client_manager.make_client("generative")
|
||||
|
||||
google_model._client = new_custom_client
|
||||
for msg in prompt_messages: # makes message roles strictly alternating
|
||||
content = self._format_message_to_glm_content(msg)
|
||||
if history and history[-1]["role"] == content["role"]:
|
||||
history[-1]["parts"].extend(content["parts"])
|
||||
else:
|
||||
history.append(content)
|
||||
|
||||
response = google_model.generate_content(
|
||||
contents=history,
|
||||
@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
content = message.content
|
||||
if isinstance(content, list):
|
||||
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
|
||||
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
|
||||
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message_text = f"{human_prompt} {content}"
|
||||
@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
    """
    Upload a prompt content item to Google file storage and return its file handle.

    The uploaded file's name is cached in Redis under a key built from the content
    type and a digest of the payload, so identical content is looked up instead of
    being uploaded again while the cache entry is alive.

    :param message_content: content whose ``data`` is either a ``data:`` URI carrying
        a base64 payload, or a URL to download (image and video content only)
    :return: the Google file handle of the uploaded or previously cached file
    :raises ValueError: if the URL cannot be fetched, the content type has no
        URL-to-MIME mapping, or Google reports that processing the file failed
    """
    # Function-scope imports keep this block self-contained.
    import hashlib
    import logging

    data = message_content.data

    # The built-in hash() is salted per process for str values, so it cannot key a
    # cache shared through Redis; a SHA-256 digest is stable across workers.
    digest = hashlib.sha256(data.encode("utf-8")).hexdigest()
    key = f"{message_content.type.value}:{digest}"

    # A single GET avoids the exists()/get() race where the key expires in between.
    cached_name = redis_client.get(key)
    if cached_name:
        if isinstance(cached_name, bytes):
            cached_name = cached_name.decode()
        try:
            return genai.get_file(cached_name)
        except Exception:
            # Best effort: the cached file may no longer exist on Google's side,
            # so fall through and upload the content again.
            logging.getLogger(__name__).warning(
                "Cached Google file %s is unavailable, uploading again", cached_name, exc_info=True
            )

    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_path = temp_file.name
            if data.startswith("data:"):
                # "data:<mime>;base64,<payload>"
                metadata, base64_data = data.split(",", 1)
                mime_type = metadata.split(";", 1)[0].split(":")[1]
                temp_file.write(base64.b64decode(base64_data))
            else:
                # Only ImagePromptMessageContent and VideoPromptMessageContent carry a URL.
                # Check the content class: ``type`` holds an enum member, never a class.
                if isinstance(message_content, ImagePromptMessageContent):
                    prefix = "image/"
                elif isinstance(message_content, VideoPromptMessageContent):
                    prefix = "video/"
                else:
                    raise ValueError(
                        f"Content type {message_content.type.value} cannot be uploaded from a url"
                    )
                mime_type = prefix + message_content.format
                try:
                    # (connect, read) timeouts so an unresponsive host cannot hang the request.
                    response = requests.get(data, timeout=(10, 60))
                    response.raise_for_status()
                except Exception as ex:
                    raise ValueError(f"Failed to fetch data from url {data}, {ex}") from ex
                temp_file.write(response.content)
        # Leaving the ``with`` block closes (and flushes) the file before it is uploaded.

        file = genai.upload_file(path=temp_path, mime_type=mime_type)
        while file.state.name == "PROCESSING":
            time.sleep(5)
            file = genai.get_file(file.name)
        if file.state.name == "FAILED":
            # Do not cache or hand back a file Google could not process.
            raise ValueError(f"Google failed to process uploaded file {file.name}")

        # Google deletes uploaded files after 2 days; expire the cache entry an hour
        # earlier (47 hours) so a stale name is never served.
        redis_client.setex(key, 47 * 60 * 60, file.name)
        return file
    finally:
        # Runs on every path, including failed downloads, so the temp file never leaks.
        if temp_path is not None:
            os.unlink(temp_path)
|
||||
|
||||
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
|
||||
"""
|
||||
Format a single message into glm.Content for Google API
|
||||
@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
for c in message.content:
|
||||
if c.type == PromptMessageContentType.TEXT:
|
||||
glm_content["parts"].append(to_part(c.data))
|
||||
elif c.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, c)
|
||||
if message_content.data.startswith("data:"):
|
||||
metadata, base64_data = c.data.split(",", 1)
|
||||
mime_type = metadata.split(";", 1)[0].split(":")[1]
|
||||
else:
|
||||
# fetch image data from url
|
||||
try:
|
||||
image_content = requests.get(message_content.data).content
|
||||
with Image.open(io.BytesIO(image_content)) as img:
|
||||
mime_type = f"image/{img.format.lower()}"
|
||||
base64_data = base64.b64encode(image_content).decode("utf-8")
|
||||
except Exception as ex:
|
||||
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
|
||||
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
|
||||
glm_content["parts"].append(blob)
|
||||
elif c.type == PromptMessageContentType.DOCUMENT:
|
||||
message_content = cast(DocumentPromptMessageContent, c)
|
||||
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
|
||||
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
|
||||
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
|
||||
glm_content["parts"].append(blob)
|
||||
else:
|
||||
glm_content["parts"].append(self._upload_file_content_to_google(c))
|
||||
|
||||
return glm_content
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
|
@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif isinstance(message_content, AudioPromptMessageContent):
|
||||
data_split = message_content.data.split(";base64,")
|
||||
base64_data = data_split[1]
|
||||
sub_message_dict = {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": message_content.data,
|
||||
"data": base64_data,
|
||||
"format": message_content.format,
|
||||
},
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import pytest
|
||||
@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.ai.generativelanguage_v1beta.types import content as gag_content
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
|
||||
current_api_key = ""
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
class MockGoogleResponseClass:
|
||||
@ -57,11 +57,6 @@ class MockGoogleClass:
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception("Invalid API key")
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
@ -75,33 +70,29 @@ class MockGoogleClass:
|
||||
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
def mock_configure(api_key: str):
    """
    Test replacement for ``google.generativeai.configure``.

    Accepts any key of at least 16 characters and raises for shorter ones.

    :param api_key: the API key under test
    :raises Exception: if the key is shorter than 16 characters
    """
    minimum_length = 16
    if not len(api_key) >= minimum_length:
        raise Exception("Invalid API key")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
class MockFileState:
    """Test double for an uploaded file's state object; exposes only ``name``."""

    def __init__(self):
        # Fixed state name reported by every mock file.
        state_name = "FINISHED"
        self.name = state_name
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
class MockGoogleFile:
    """Test double for a Google file handle, carrying a ``name`` and a ``state``."""

    def __init__(self, name: str = "mock_file_name"):
        """
        :param name: file name to report; defaults to ``"mock_file_name"``
        """
        # Every mock file gets its own freshly created state object.
        file_state = MockFileState()
        self.name = name
        self.state = file_state
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
def mock_get_file(name: str) -> MockGoogleFile:
    """
    Test replacement for ``google.generativeai.get_file``.

    :param name: name of the file being looked up
    :return: a new mock file carrying that name
    """
    looked_up = MockGoogleFile(name)
    return looked_up
|
||||
|
||||
|
||||
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
    """
    Test replacement for ``google.generativeai.upload_file``.

    Both arguments are accepted for signature compatibility and otherwise ignored.

    :param path: path of the file to "upload"
    :param mime_type: MIME type of the file
    :return: a new mock file with the default name
    """
    uploaded = MockGoogleFile()
    return uploaded
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
monkeypatch.setattr("google.generativeai.configure", mock_configure)
|
||||
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
|
||||
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_mock_redis() -> None:
    """
    Stub the Redis client methods used by the Google file-upload cache.

    ``get`` and ``setex`` return ``None`` and ``exists`` returns ``True``. The stubs
    are assigned directly on the shared client object and are not removed afterwards.
    """
    stubbed_results = {
        "get": None,
        "setex": None,
        "exists": True,
    }
    client = ext_redis.redis_client
    for method_name, result in stubbed_results.items():
        setattr(client, method_name, MagicMock(return_value=result))
|
||||
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
def test_invoke_chat_model_with_vision(setup_google_mock):
|
||||
def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
|
||||
model = GoogleLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
|
||||
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
|
||||
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
|
||||
model = GoogleLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
|
@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
extension=".jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
extension=".jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user