fix: remote_url doesn't work for gemini (#5090)

This commit is contained in:
rerorero 2024-06-12 14:14:53 +09:00 committed by GitHub
parent b7c72f7a97
commit 28997772a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,18 +1,22 @@
import base64
import json import json
import logging import logging
import mimetypes
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union, cast
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
import google.api_core.exceptions as exceptions import google.api_core.exceptions as exceptions
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.client as client import google.generativeai.client as client
import requests
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part from google.generativeai.types.content_types import to_part
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
PromptMessageTool, PromptMessageTool,
@ -361,11 +365,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content: for c in message.content:
if c.type == PromptMessageContentType.TEXT: if c.type == PromptMessageContentType.TEXT:
glm_content['parts'].append(to_part(c.data)) glm_content['parts'].append(to_part(c.data))
else: elif c.type == PromptMessageContentType.IMAGE:
metadata, data = c.data.split(',', 1) message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1] mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}} else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}}
glm_content['parts'].append(blob) glm_content['parts'].append(blob)
return glm_content return glm_content
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
glm_content = { glm_content = {