feat: enhance gemini models (#11497)

This commit is contained in:
非法操作 2024-12-17 12:05:13 +08:00 committed by GitHub
parent 56cfdce453
commit 74fdc16bd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 138 additions and 113 deletions

View File

@ -50,12 +50,12 @@ def to_prompt_message_content(
else:
data = _to_base64_data_string(f)
return ImagePromptMessageContent(data=data, detail=image_detail_config)
return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
@ -65,14 +65,8 @@ def to_prompt_message_content(
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")

View File

@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
format: str = Field(..., description="Document format")
class PromptMessage(ABC, BaseModel):

View File

@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
f"Unsupported document type {mime_type}, " "only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
"media_type": mime_type,
"data": base64_data,
},
}
sub_messages.append(sub_message_dict)

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@ -1,29 +1,30 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Optional, Union
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
from extensions.ext_redis import redis_client
class GoogleLargeLanguageModel(LargeLanguageModel):
@ -201,29 +188,17 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
response = google_model.generate_content(
contents=history,
@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return message_text
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
    """Upload a prompt-message payload to the Google File API and return the handle.

    Uploaded handles are cached in Redis for 47 hours, keyed by content type
    plus a hash of the payload, because Google deletes uploaded files after
    2 days and re-uploads are expensive.

    :param message_content: content whose ``data`` is either a ``data:`` base64
        URI or a plain URL (only image and video content carry plain URLs).
    :return: the uploaded (or cache-hit) Google ``File``.
    :raises ValueError: when the URL cannot be fetched or the content type
        cannot be mapped to a mime type.
    """
    key = f"{message_content.type.value}:{hash(message_content.data)}"
    # Single GET instead of EXISTS + GET: the key could expire between the two
    # calls, in which case get() returns None and .decode() would crash.
    cached_name = redis_client.get(key)
    if cached_name:
        try:
            return genai.get_file(cached_name.decode())
        except Exception:
            # The remote file may already have been purged by Google;
            # fall through and upload again.
            pass
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        if message_content.data.startswith("data:"):
            metadata, base64_data = message_content.data.split(",", 1)
            file_content = base64.b64decode(base64_data)
            mime_type = metadata.split(";", 1)[0].split(":")[1]
            temp_file.write(file_content)
        else:
            # only ImagePromptMessageContent and VideoPromptMessageContent has url
            # BUGFIX: the previous code compared the enum value against the
            # content *class* with ``is`` (always False), leaving ``mime_type``
            # unbound. Compare enum members instead, and fail loudly for any
            # other content type.
            if message_content.type == PromptMessageContentType.IMAGE:
                prefix = "image/"
            elif message_content.type == PromptMessageContentType.VIDEO:
                prefix = "video/"
            else:
                raise ValueError(f"Unsupported content type {message_content.type} for url data")
            mime_type = prefix + message_content.format
            try:
                response = requests.get(message_content.data)
                response.raise_for_status()
                temp_file.write(response.content)
            except Exception as ex:
                raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}") from ex
        temp_file.flush()
        try:
            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
            # Poll until Google finishes server-side processing of the upload.
            while file.state.name == "PROCESSING":
                time.sleep(5)
                file = genai.get_file(file.name)
            # google will delete your upload files in 2 days.
            redis_client.setex(key, 47 * 60 * 60, file.name)
            return file
        finally:
            os.unlink(temp_file.name)
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
else:
glm_content["parts"].append(self._upload_file_content_to_google(c))
return glm_content
elif isinstance(message, AssistantPromptMessage):

View File

@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, AudioPromptMessageContent):
data_split = message_content.data.split(";base64,")
base64_data = data_split[1]
sub_message_dict = {
"type": "input_audio",
"input_audio": {
"data": message_content.data,
"data": base64_data,
"format": message_content.format,
},
}

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import google.generativeai.types.generation_types as generation_config_types
import pytest
@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel
from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ""
from extensions import ext_redis
class MockGoogleResponseClass:
@ -57,11 +57,6 @@ class MockGoogleClass:
stream: bool = False,
**kwargs,
) -> GenerateContentResponse:
global current_api_key
if len(current_api_key) < 16:
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
@ -75,33 +70,29 @@ class MockGoogleClass:
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
def mock_configure(api_key: str):
    """Stand-in for ``google.generativeai.configure`` that only validates key length."""
    key_is_plausible = len(api_key) >= 16
    if not key_is_plausible:
        raise Exception("Invalid API key")
# Attempt to configure using defaults.
if not self.client_config:
configure()
client_options = self.client_config.get("client_options", None)
if client_options:
current_api_key = client_options.api_key
class MockFileState:
    """Minimal stand-in for a Google file-state object."""

    def __init__(self) -> None:
        # Always report a completed upload so polling loops exit immediately.
        self.name: str = "FINISHED"
def nop(self, *args, **kwargs):
pass
original_init = cls.__init__
cls.__init__ = nop
client: glm.GenerativeServiceClient = cls(**self.client_config)
cls.__init__ = original_init
class MockGoogleFile:
    """Minimal stand-in for a file handle returned by the Google File API."""

    def __init__(self, name: str = "mock_file_name") -> None:
        self.name = name
        # Finished state so callers that poll ``file.state.name`` do not loop.
        self.state = MockFileState()
if not self.default_metadata:
return client
def mock_get_file(name: str) -> MockGoogleFile:
    """Stand-in for ``genai.get_file``: echo back a mock file carrying the requested name."""
    return MockGoogleFile(name=name)
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
    """Stand-in for ``genai.upload_file``: ignores its inputs and returns a default mock file."""
    uploaded = MockGoogleFile()
    return uploaded
@pytest.fixture
@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
monkeypatch.setattr("google.generativeai.configure", mock_configure)
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
yield
monkeypatch.undo()
@pytest.fixture
def setup_mock_redis() -> None:
    """Patch the shared redis client with MagicMocks for Google-upload tests.

    NOTE(review): ``exists`` reports True while ``get`` yields None —
    presumably to exercise the cache-fallback path; confirm against the
    code under test.
    """
    for method_name, stub_value in (("get", None), ("setex", None), ("exists", True)):
        setattr(ext_redis.redis_client, method_name, MagicMock(return_value=stub_value))

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision(setup_google_mock):
def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel()
result = model.invoke(
@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel()
result = model.invoke(

View File

@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)