feat: enhance gemini models (#11497)

parent 56cfdce453
commit 74fdc16bd1
@@ -50,12 +50,12 @@ def to_prompt_message_content(
             else:
                 data = _to_base64_data_string(f)

-            return ImagePromptMessageContent(data=data, detail=image_detail_config)
+            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
         case FileType.AUDIO:
-            encoded_string = _get_encoded_string(f)
+            data = _to_base64_data_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.VIDEO:
             if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
                 data = _to_url(f)
@@ -65,14 +65,8 @@ def to_prompt_message_content(
                 raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.DOCUMENT:
-            data = _get_encoded_string(f)
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-            return DocumentPromptMessageContent(
-                encode_format="base64",
-                mime_type=f.mime_type,
-                data=data,
-            )
+            data = _to_base64_data_string(f)
+            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")

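The two hunks above converge image, audio, and document contents on a single base64 helper and attach the file format (derived from the extension) to each content object. A minimal sketch of the data-string convention this implies — the helper body below is an assumption, inferred from the `;base64,` splits performed in later hunks:

    import base64

    def to_base64_data_string(raw: bytes, mime_type: str) -> str:
        # Assumed shape of _to_base64_data_string's output: a data URI,
        # e.g. "data:application/pdf;base64,JVBERi0x..."
        return f"data:{mime_type};base64,{base64.b64encode(raw).decode()}"

    data = to_base64_data_string(b"%PDF-1.7", "application/pdf")
    head, payload = data.split(";base64,")  # how downstream providers unpack it
    assert head == "data:application/pdf"
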
@@ -101,13 +101,14 @@ class ImagePromptMessageContent(PromptMessageContent):

     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
+    format: str = Field("jpg", description="Image format")


 class DocumentPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
     encode_format: Literal["base64"]
-    mime_type: str
     data: str
+    format: str = Field(..., description="Document format")


 class PromptMessage(ABC, BaseModel):

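With the entity change above, the format now travels on the prompt content itself, and DocumentPromptMessageContent no longer carries a mime_type. A hedged construction example (field values are illustrative):

    image = ImagePromptMessageContent(
        data="data:image/png;base64,iVBORw0KGgo...",
        format="png",  # defaults to "jpg" when omitted
    )
    document = DocumentPromptMessageContent(
        encode_format="base64",
        data="data:application/pdf;base64,JVBERi0x...",
        format="pdf",  # now required, in place of mime_type
    )
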
@@ -526,17 +526,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                         }
                         sub_messages.append(sub_message_dict)
                     elif isinstance(message_content, DocumentPromptMessageContent):
-                        if message_content.mime_type != "application/pdf":
+                        data_split = message_content.data.split(";base64,")
+                        mime_type = data_split[0].replace("data:", "")
+                        base64_data = data_split[1]
+                        if mime_type != "application/pdf":
                             raise ValueError(
-                                f"Unsupported document type {message_content.mime_type}, "
-                                "only support application/pdf"
+                                f"Unsupported document type {mime_type}, " "only support application/pdf"
                             )
                         sub_message_dict = {
                             "type": "document",
                             "source": {
                                 "type": message_content.encode_format,
-                                "media_type": message_content.mime_type,
-                                "data": message_content.data,
+                                "media_type": mime_type,
+                                "data": base64_data,
                             },
                         }
                         sub_messages.append(sub_message_dict)

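Since the content object no longer exposes mime_type, the Anthropic handler derives it from the data URI itself. The parsing step in isolation, with an illustrative payload:

    data = "data:application/pdf;base64,JVBERi0xLjcK..."
    head, base64_data = data.split(";base64,")
    mime_type = head.replace("data:", "")  # -> "application/pdf"
    source = {"type": "base64", "media_type": mime_type, "data": base64_data}
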
@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 2097152

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767

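Each of the fifteen YAML hunks above adds the same capability flags to one gemini model definition. After the change a model file declares roughly the following (an illustrative excerpt; surrounding keys vary per model):

    features:
      - vision
      - tool-call
      - stream-tool-call
      - document
      - video
      - audio
    model_properties:
      mode: chat
      context_size: 1048576
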
@@ -1,29 +1,30 @@
 import base64
-import io
 import json
+import os
+import tempfile
+import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union

 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import requests
 from google.api_core import exceptions
-from google.generativeai.client import _ClientManager
-from google.generativeai.types import ContentType, GenerateContentResponse
+from google.generativeai.types import ContentType, File, GenerateContentResponse
 from google.generativeai.types.content_types import to_part
-from PIL import Image

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,

@@ -35,21 +36,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-GOOGLE_AVAILABLE_MIMETYPE = [
-    "application/pdf",
-    "application/x-javascript",
-    "text/javascript",
-    "application/x-python",
-    "text/x-python",
-    "text/plain",
-    "text/html",
-    "text/css",
-    "text/md",
-    "text/csv",
-    "text/xml",
-    "text/rtf",
-]
+from extensions.ext_redis import redis_client


 class GoogleLargeLanguageModel(LargeLanguageModel):

@@ -201,16 +188,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         if stop:
             config_kwargs["stop_sequences"] = stop

+        genai.configure(api_key=credentials["google_api_key"])
         google_model = genai.GenerativeModel(model_name=model)

         history = []

-        # hack for gemini-pro-vision, which currently does not support multi-turn chat
-        if model == "gemini-pro-vision":
-            last_msg = prompt_messages[-1]
-            content = self._format_message_to_glm_content(last_msg)
-            history.append(content)
-        else:
         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:

@@ -218,13 +200,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             else:
                 history.append(content)

-        # Create a new ClientManager with tenant's API key
-        new_client_manager = _ClientManager()
-        new_client_manager.configure(api_key=credentials["google_api_key"])
-        new_custom_client = new_client_manager.make_client("generative")
-
-        google_model._client = new_custom_client
-
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),

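Together, these two hunks drop the private _ClientManager plumbing in favor of the library's public configuration call. The resulting call pattern, sketched standalone (model name and parameters are placeholders):

    import google.generativeai as genai

    genai.configure(api_key="<google_api_key>")  # replaces the _ClientManager wiring
    model = genai.GenerativeModel(model_name="gemini-1.5-pro")
    response = model.generate_content(
        contents="Hello",
        generation_config=genai.types.GenerationConfig(temperature=0.7, stop_sequences=["\n\n"]),
    )
    print(response.text)
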
@@ -346,7 +321,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

         content = message.content
         if isinstance(content, list):
-            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
+            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)

         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"

@@ -359,6 +334,44 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

         return message_text

+    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
+        key = f"{message_content.type.value}:{hash(message_content.data)}"
+        if redis_client.exists(key):
+            try:
+                return genai.get_file(redis_client.get(key).decode())
+            except:
+                pass
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            if message_content.data.startswith("data:"):
+                metadata, base64_data = message_content.data.split(",", 1)
+                file_content = base64.b64decode(base64_data)
+                mime_type = metadata.split(";", 1)[0].split(":")[1]
+                temp_file.write(file_content)
+            else:
+                # only ImagePromptMessageContent and VideoPromptMessageContent has url
+                try:
+                    response = requests.get(message_content.data)
+                    response.raise_for_status()
+                    if message_content.type is ImagePromptMessageContent:
+                        prefix = "image/"
+                    elif message_content.type is VideoPromptMessageContent:
+                        prefix = "video/"
+                    mime_type = prefix + message_content.format
+                    temp_file.write(response.content)
+                except Exception as ex:
+                    raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}")
+            temp_file.flush()
+        try:
+            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
+            while file.state.name == "PROCESSING":
+                time.sleep(5)
+                file = genai.get_file(file.name)
+            # google will delete your upload files in 2 days.
+            redis_client.setex(key, 47 * 60 * 60, file.name)
+            return file
+        finally:
+            os.unlink(temp_file.name)
+
     def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
         """
         Format a single message into glm.Content for Google API

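The new _upload_file_content_to_google method pushes media through Gemini's Files API and caches the returned file name in Redis, since Google deletes uploads after about two days (hence the 47-hour TTL, strictly shorter than the retention window). A condensed sketch of that cache-aside flow; get_or_upload is a hypothetical helper, not part of the commit:

    import time

    import google.generativeai as genai

    from extensions.ext_redis import redis_client


    def get_or_upload(key: str, path: str, mime_type: str):
        # Reuse a previous upload if its name is still cached.
        if redis_client.exists(key):
            return genai.get_file(redis_client.get(key).decode())
        file = genai.upload_file(path=path, mime_type=mime_type)
        while file.state.name == "PROCESSING":  # uploads are processed asynchronously
            time.sleep(5)
            file = genai.get_file(file.name)
        redis_client.setex(key, 47 * 60 * 60, file.name)  # cache expires before Google deletes
        return file
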
@@ -374,28 +387,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             for c in message.content:
                 if c.type == PromptMessageContentType.TEXT:
                     glm_content["parts"].append(to_part(c.data))
-                elif c.type == PromptMessageContentType.IMAGE:
-                    message_content = cast(ImagePromptMessageContent, c)
-                    if message_content.data.startswith("data:"):
-                        metadata, base64_data = c.data.split(",", 1)
-                        mime_type = metadata.split(";", 1)[0].split(":")[1]
-                    else:
-                        # fetch image data from url
-                        try:
-                            image_content = requests.get(message_content.data).content
-                            with Image.open(io.BytesIO(image_content)) as img:
-                                mime_type = f"image/{img.format.lower()}"
-                            base64_data = base64.b64encode(image_content).decode("utf-8")
-                        except Exception as ex:
-                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                    blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
-                    glm_content["parts"].append(blob)
-                elif c.type == PromptMessageContentType.DOCUMENT:
-                    message_content = cast(DocumentPromptMessageContent, c)
-                    if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
-                        raise ValueError(f"Unsupported mime type {message_content.mime_type}")
-                    blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
-                    glm_content["parts"].append(blob)
+                else:
+                    glm_content["parts"].append(self._upload_file_content_to_google(c))

             return glm_content
         elif isinstance(message, AssistantPromptMessage):

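With the per-type branches above collapsed, every non-text part is routed through the Files API upload and appended as a genai File handle. End to end, that enables the standard mixed-content call (a sketch with placeholder file and model names):

    import google.generativeai as genai

    genai.configure(api_key="<google_api_key>")
    model = genai.GenerativeModel(model_name="gemini-1.5-flash")
    uploaded = genai.upload_file(path="clip.mp4", mime_type="video/mp4")
    # Strings become text parts; File handles reference the uploaded media.
    response = model.generate_content(["Describe this clip.", uploaded])
    print(response.text)
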
@@ -920,10 +920,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     }
                     sub_messages.append(sub_message_dict)
                 elif isinstance(message_content, AudioPromptMessageContent):
+                    data_split = message_content.data.split(";base64,")
+                    base64_data = data_split[1]
                     sub_message_dict = {
                         "type": "input_audio",
                         "input_audio": {
-                            "data": message_content.data,
+                            "data": base64_data,
                             "format": message_content.format,
                         },
                     }

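The OpenAI handler gets the mirror-image treatment for audio: the base64 payload is cut out of the data URI before it is placed in the input_audio part. Isolated, with an illustrative payload:

    data = "data:audio/wav;base64,UklGRiQ..."
    base64_data = data.split(";base64,")[1]
    part = {
        "type": "input_audio",
        "input_audio": {"data": base64_data, "format": "wav"},
    }
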
@@ -1,4 +1,5 @@
 from collections.abc import Generator
+from unittest.mock import MagicMock

 import google.generativeai.types.generation_types as generation_config_types
 import pytest

@@ -6,11 +7,10 @@ from _pytest.monkeypatch import MonkeyPatch
 from google.ai import generativelanguage as glm
 from google.ai.generativelanguage_v1beta.types import content as gag_content
 from google.generativeai import GenerativeModel
-from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse, content_types, safety_types
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
-
-current_api_key = ""
+from extensions import ext_redis


 class MockGoogleResponseClass:

@@ -57,11 +57,6 @@ class MockGoogleClass:
         stream: bool = False,
         **kwargs,
     ) -> GenerateContentResponse:
-        global current_api_key
-
-        if len(current_api_key) < 16:
-            raise Exception("Invalid API key")
-
         if stream:
             return MockGoogleClass.generate_content_stream()

@@ -75,33 +70,29 @@ class MockGoogleClass:
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]

-    def make_client(self: _ClientManager, name: str):
-        global current_api_key
-
-        if name.endswith("_async"):
-            name = name.split("_")[0]
-            cls = getattr(glm, name.title() + "ServiceAsyncClient")
-        else:
-            cls = getattr(glm, name.title() + "ServiceClient")
-
-        # Attempt to configure using defaults.
-        if not self.client_config:
-            configure()
-
-        client_options = self.client_config.get("client_options", None)
-        if client_options:
-            current_api_key = client_options.api_key
-
-        def nop(self, *args, **kwargs):
-            pass
-
-        original_init = cls.__init__
-        cls.__init__ = nop
-        client: glm.GenerativeServiceClient = cls(**self.client_config)
-        cls.__init__ = original_init
-
-        if not self.default_metadata:
-            return client
+
+def mock_configure(api_key: str):
+    if len(api_key) < 16:
+        raise Exception("Invalid API key")
+
+
+class MockFileState:
+    def __init__(self):
+        self.name = "FINISHED"
+
+
+class MockGoogleFile:
+    def __init__(self, name: str = "mock_file_name"):
+        self.name = name
+        self.state = MockFileState()
+
+
+def mock_get_file(name: str) -> MockGoogleFile:
+    return MockGoogleFile(name)
+
+
+def mock_upload_file(path: str, mime_type: str) -> MockGoogleFile:
+    return MockGoogleFile()


 @pytest.fixture
@@ -109,8 +100,17 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
     monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
     monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
-    monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
+    monkeypatch.setattr("google.generativeai.configure", mock_configure)
+    monkeypatch.setattr("google.generativeai.get_file", mock_get_file)
+    monkeypatch.setattr("google.generativeai.upload_file", mock_upload_file)

     yield

     monkeypatch.undo()
+
+
+@pytest.fixture
+def setup_mock_redis() -> None:
+    ext_redis.redis_client.get = MagicMock(return_value=None)
+    ext_redis.redis_client.setex = MagicMock(return_value=None)
+    ext_redis.redis_client.exists = MagicMock(return_value=True)

@@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
-from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
+from tests.integration_tests.model_runtime.__mock.google import setup_google_mock, setup_mock_redis


 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)

@@ -95,7 +95,7 @@ def test_invoke_stream_model(setup_google_mock):


 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision(setup_google_mock):
+def test_invoke_chat_model_with_vision(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()

     result = model.invoke(

@@ -124,7 +124,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):


 @pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
-def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
+def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock, setup_mock_redis):
     model = GoogleLargeLanguageModel()

     result = model.invoke(

@@ -326,6 +326,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
+                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
             )

@@ -395,6 +396,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
                 tenant_id="test",
                 type=FileType.IMAGE,
                 filename="test1.jpg",
+                extension=".jpg",
                 transfer_method=FileTransferMethod.REMOTE_URL,
                 remote_url=fake_remote_url,
             )