feat: enhance gemini models (#11497)

This commit is contained in:
非法操作 2024-12-17 12:05:13 +08:00 committed by GitHub
parent 56cfdce453
commit 74fdc16bd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 138 additions and 113 deletions

View File

@ -50,12 +50,12 @@ def to_prompt_message_content(
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
return ImagePromptMessageContent(data=data, detail=image_detail_config) return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _get_encoded_string(f) data = _to_base64_data_string(f)
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO: case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f) data = _to_url(f)
@ -65,14 +65,8 @@ def to_prompt_message_content(
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT: case FileType.DOCUMENT:
data = _get_encoded_string(f) data = _to_base64_data_string(f)
if f.mime_type is None: return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _: case _:
raise ValueError(f"file type {f.type} is not supported") raise ValueError(f"file type {f.type} is not supported")

View File

@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.IMAGE type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")
class DocumentPromptMessageContent(PromptMessageContent): class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"] encode_format: Literal["base64"]
mime_type: str
data: str data: str
format: str = Field(..., description="Document format")
class PromptMessage(ABC, BaseModel): class PromptMessage(ABC, BaseModel):

View File

@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent): elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf": data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
raise ValueError( raise ValueError(
f"Unsupported document type {message_content.mime_type}, " f"Unsupported document type {mime_type}, " "only support application/pdf"
"only support application/pdf"
) )
sub_message_dict = { sub_message_dict = {
"type": "document", "type": "document",
"source": { "source": {
"type": message_content.encode_format, "type": message_content.encode_format,
"media_type": message_content.mime_type, "media_type": mime_type,
"data": message_content.data, "data": base64_data,
}, },
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

View File

@ -8,6 +8,8 @@ features:
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document - document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 32767 context_size: 32767

View File

@ -7,6 +7,9 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
- video
- audio
model_properties: model_properties:
mode: chat mode: chat
context_size: 32767 context_size: 32767

View File

@ -1,29 +1,30 @@
import base64 import base64
import io
import json import json
import os
import tempfile
import time
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
import google.generativeai as genai import google.generativeai as genai
import requests import requests
from google.api_core import exceptions from google.api_core import exceptions
from google.generativeai.client import _ClientManager from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types.content_types import to_part from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageContentType, PromptMessageContentType,
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
VideoPromptMessageContent,
) )
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_redis import redis_client
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
class GoogleLargeLanguageModel(LargeLanguageModel): class GoogleLargeLanguageModel(LargeLanguageModel):
@ -201,16 +188,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
if stop: if stop:
config_kwargs["stop_sequences"] = stop config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model) google_model = genai.GenerativeModel(model_name=model)
history = [] history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg) content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]: if history and history[-1]["role"] == content["role"]:
@ -218,13 +200,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
else: else:
history.append(content) history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
response = google_model.generate_content( response = google_model.generate_content(
contents=history, contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs), generation_config=genai.types.GenerationConfig(**config_kwargs),
@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content content = message.content
if isinstance(content, list): if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage): if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}" message_text = f"{human_prompt} {content}"
@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return message_text return message_text
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
key = f"{message_content.type.value}:{hash(message_content.data)}"
if redis_client.exists(key):
try:
return genai.get_file(redis_client.get(key).decode())
except:
pass
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
if message_content.data.startswith("data:"):
metadata, base64_data = message_content.data.split(",", 1)
file_content = base64.b64decode(base64_data)
mime_type = metadata.split(";", 1)[0].split(":")[1]
temp_file.write(file_content)
else:
# only ImagePromptMessageContent and VideoPromptMessageContent has url
try:
response = requests.get(message_content.data)
response.raise_for_status()
if message_content.type is ImagePromptMessageContent:
prefix = "image/"
elif message_content.type is VideoPromptMessageContent:
prefix = "video/"
mime_type = prefix + message_content.format
temp_file.write(response.content)
except Exception as ex:
raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}")
temp_file.flush()
try:
file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
while file.state.name == "PROCESSING":
time.sleep(5)
file = genai.get_file(file.name)
# google will delete your upload files in 2 days.
redis_client.setex(key, 47 * 60 * 60, file.name)
return file
finally:
os.unlink(temp_file.name)
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
""" """
Format a single message into glm.Content for Google API Format a single message into glm.Content for Google API
@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content: for c in message.content:
if c.type == PromptMessageContentType.TEXT: if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data)) glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else: else:
# fetch image data from url glm_content["parts"].append(self._upload_file_content_to_google(c))
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
return glm_content return glm_content
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):

View File

@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif isinstance(message_content, AudioPromptMessageContent): elif isinstance(message_content, AudioPromptMessageContent):
data_split = message_content.data.split(";base64,")
base64_data = data_split[1]
sub_message_dict = { sub_message_dict = {
"type": "input_audio", "type": "input_audio",
"input_audio": { "input_audio": {
"data": message_content.data, "data": base64_data,
"format": message_content.format, "format": message_content.format,
}, },
} }

View File

@ -1,4 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from unittest.mock import MagicMock
import google.generativeai.types.generation_types as generation_config_types import google.generativeai.types.generation_types as generation_config_types
import pytest import pytest
@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
from google.ai import generativelanguage as glm from google.ai import generativelanguage as glm
from google.ai.generativelanguage_v1beta.types import content as gag_content from google.ai.generativelanguage_v1beta.types import content as gag_content
from google.generativeai import GenerativeModel from google.generativeai import GenerativeModel
from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types import GenerateContentResponse, content_types, safety_types
from google.generativeai.types.generation_types import BaseGenerateContentResponse from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = "" from extensions import ext_redis
class MockGoogleResponseClass: class MockGoogleResponseClass:
@ -57,11 +57,6 @@ class MockGoogleClass:
stream: bool = False, stream: bool = False,
**kwargs, **kwargs,
) -> GenerateContentResponse: ) -> GenerateContentResponse:
global current_api_key
if len(current_api_key) < 16:
raise Exception("Invalid API key")
if stream: if stream:
return MockGoogleClass.generate_content_stream() return MockGoogleClass.generate_content_stream()
@ -75,33 +70,29 @@ class MockGoogleClass:
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()] return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
if name.endswith("_async"): def mock_configure(api_key: str):
name = name.split("_")[0] if len(api_key) < 16:
cls = getattr(glm, name.title() + "ServiceAsyncClient") raise Exception("Invalid API key")
else:
cls = getattr(glm, name.title() + "ServiceClient")
# Attempt to configure using defaults.
if not self.client_config:
configure()
client_options = self.client_config.get("client_options", None) class MockFileState:
if client_options: def __init__(self):
current_api_key = client_options.api_key self.name = "FINISHED"
def nop(self, *args, **kwargs):
pass
original_init = cls.__init__ class MockGoogleFile:
cls.__init__ = nop def __init__(self, name: str = "mock_file_name"):
client: glm.GenerativeServiceClient = cls(**self.client_config) self.name = name
cls.__init__ = original_init self.state = MockFileState()
if not self.default_metadata:
return client def mock_get_file(name: str) -> MockGoogleFile:
return MockGoogleFile(name)
def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
return MockGoogleFile()
@pytest.fixture @pytest.fixture
@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates) monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content) monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client) monkeypatch.setattr("google.generativeai.configure", mock_configure)
monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)
yield yield
monkeypatch.undo() monkeypatch.undo()
@pytest.fixture
def setup_mock_redis() -> None:
ext_redis.redis_client.get = MagicMock(return_value=None)
ext_redis.redis_client.setex = MagicMock(return_value=None)
ext_redis.redis_client.exists = MagicMock(return_value=True)

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision(setup_google_mock): def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel() model = GoogleLargeLanguageModel()
result = model.invoke( result = model.invoke(
@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock): def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
model = GoogleLargeLanguageModel() model = GoogleLargeLanguageModel()
result = model.invoke( result = model.invoke(

View File

@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test1.jpg", filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url, remote_url=fake_remote_url,
) )
@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
tenant_id="test", tenant_id="test",
type=FileType.IMAGE, type=FileType.IMAGE,
filename="test1.jpg", filename="test1.jpg",
extension=".jpg",
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url, remote_url=fake_remote_url,
) )