fix: Correct image parameter passing in GLM-4v model API calls (#2948)

This commit is contained in:
Weishan-0 2024-03-26 10:43:20 +08:00 committed by GitHub
parent 08a5afcf9f
commit a676d4387c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageContentType, PromptMessageContentType,
PromptMessageRole, PromptMessageRole,
PromptMessageTool, PromptMessageTool,
@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
```JSON""" ```JSON"""
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, def _invoke(self, model: str, credentials: dict,
@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
new_prompt_messages.append(copy_prompt_message) new_prompt_messages.append(copy_prompt_message)
if model == 'glm-4v': if model == 'glm-4v':
params = { params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
'model': model,
'messages': [{
'role': prompt_message.role.value,
'content':
[
{
'type': 'text',
'text': prompt_message.content
}
] if isinstance(prompt_message.content, str) else
[
{
'type': 'image',
'image_url': {
'url': content.data
}
} if content.type == PromptMessageContentType.IMAGE else {
'type': 'text',
'text': content.data
} for content in prompt_message.content
],
} for prompt_message in new_prompt_messages],
**model_parameters
}
else: else:
params = { params = {
'model': model, 'model': model,
@ -307,6 +285,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
response = client.chat.completions.create(**params, **extra_model_kwargs) response = client.chat.completions.create(**params, **extra_model_kwargs)
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage],
                                model_parameters: dict):
    """
    Build the request parameters for a glm-4v chat completion call.

    :param model: model name, stored under the 'model' key
    :param prompt_messages: prompt messages converted into the 'messages' list
    :param model_parameters: extra parameters merged in last, so they overwrite
        'model' or 'messages' on a key collision
    :return: dict of request parameters
    """
    converted = []
    for entry in prompt_messages:
        converted.append({
            'role': entry.role.value,
            'content': self._construct_glm_4v_messages(entry.content),
        })

    request = {'model': model, 'messages': converted}
    # merged after the fixed keys, matching the precedence of a trailing ** unpack
    request.update(model_parameters)
    return request
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
    """
    Convert the content of one prompt message into a list of glm-4v content parts.

    :param prompt_message: message content, either a plain string or a list of content items
    :return: list of content-part dicts; a plain string becomes a single 'text' part
    """
    if isinstance(prompt_message, str):
        return [{'type': 'text', 'text': prompt_message}]

    parts = []
    for item in prompt_message:
        if item.type == PromptMessageContentType.IMAGE:
            # image data is passed through _remove_image_header before being used as the url
            parts.append({
                'type': 'image_url',
                'image_url': {'url': self._remove_image_header(item.data)},
            })
        else:
            # every non-image item is sent as a text part
            parts.append({'type': 'text', 'text': item.data})
    return parts
def _remove_image_header(self, image: str) -> str:
    """
    Strip the data-URI header from an image string.

    :param image: image string, possibly of the form 'data:image/...,<payload>'
    :return: everything after the first comma when image starts with 'data:image'
        and contains a comma; otherwise image unchanged
    """
    if image.startswith('data:image'):
        # partition keeps the whole payload even if it contains further commas,
        # and does not raise when the separator is missing
        _, separator, payload = image.partition(',')
        if separator:
            return payload
    return image
def _handle_generate_response(self, model: str, def _handle_generate_response(self, model: str,
credentials: dict, credentials: dict,
tools: Optional[list[PromptMessageTool]], tools: Optional[list[PromptMessageTool]],
@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
return message_text return message_text
def _convert_messages_to_prompt(self, messages: list[PromptMessage],
def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: tools: Optional[list[PromptMessageTool]] = None) -> str:
""" """
:param messages: List of PromptMessage to combine. :param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags. :return: Combined string with necessary human_prompt and ai_prompt tags.