mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 18:19:11 +08:00
fix: Correct image parameter passing in GLM-4v model API calls (#2948)
This commit is contained in:
parent
08a5afcf9f
commit
a676d4387c
@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
|||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
PromptMessageTool,
|
PromptMessageTool,
|
||||||
@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
|
|||||||
|
|
||||||
```JSON"""
|
```JSON"""
|
||||||
|
|
||||||
|
|
||||||
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||||
|
|
||||||
def _invoke(self, model: str, credentials: dict,
|
def _invoke(self, model: str, credentials: dict,
|
||||||
@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
new_prompt_messages.append(copy_prompt_message)
|
new_prompt_messages.append(copy_prompt_message)
|
||||||
|
|
||||||
if model == 'glm-4v':
|
if model == 'glm-4v':
|
||||||
params = {
|
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
||||||
'model': model,
|
|
||||||
'messages': [{
|
|
||||||
'role': prompt_message.role.value,
|
|
||||||
'content':
|
|
||||||
[
|
|
||||||
{
|
|
||||||
'type': 'text',
|
|
||||||
'text': prompt_message.content
|
|
||||||
}
|
|
||||||
] if isinstance(prompt_message.content, str) else
|
|
||||||
[
|
|
||||||
{
|
|
||||||
'type': 'image',
|
|
||||||
'image_url': {
|
|
||||||
'url': content.data
|
|
||||||
}
|
|
||||||
} if content.type == PromptMessageContentType.IMAGE else {
|
|
||||||
'type': 'text',
|
|
||||||
'text': content.data
|
|
||||||
} for content in prompt_message.content
|
|
||||||
],
|
|
||||||
} for prompt_message in new_prompt_messages],
|
|
||||||
**model_parameters
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
params = {
|
params = {
|
||||||
'model': model,
|
'model': model,
|
||||||
@ -307,6 +285,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
response = client.chat.completions.create(**params, **extra_model_kwargs)
|
response = client.chat.completions.create(**params, **extra_model_kwargs)
|
||||||
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
|
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
|
||||||
|
|
||||||
|
def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage],
                                model_parameters: dict) -> dict:
    """
    Build the request payload for the glm-4v multimodal model.

    :param model: model name (expected to be 'glm-4v')
    :param prompt_messages: prompt messages to convert into the API wire format
    :param model_parameters: extra model parameters merged into the payload
    :return: keyword arguments for the chat-completions call
    """
    messages = [
        {
            'role': message.role.value,
            # glm-4v takes a list of typed content blocks, never a bare string
            'content': self._construct_glm_4v_messages(message.content),
        }
        for message in prompt_messages
    ]

    return {
        'model': model,
        'messages': messages,
        **model_parameters,
    }
|
|
||||||
|
def _construct_glm_4v_messages(self, prompt_message: str | list[PromptMessageContent]) -> list[dict]:
    """
    Convert one prompt message's content into glm-4v content blocks.

    A plain string becomes a single text block; a list of content items is
    mapped to image_url blocks (with any data-URI header stripped, as the
    API expects raw base64) or text blocks, according to each item's type.

    :param prompt_message: message content, either raw text or content items
    :return: list of glm-4v content dicts
    """
    # NOTE: annotation fixed from `Union[str | list[...]]`, which redundantly
    # wrapped a PEP 604 union in a single-argument typing.Union.
    if isinstance(prompt_message, str):
        return [{'type': 'text', 'text': prompt_message}]

    return [
        {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}}
        if item.type == PromptMessageContentType.IMAGE
        else {'type': 'text', 'text': item.data}
        for item in prompt_message
    ]
|
||||||
|
|
||||||
|
def _remove_image_header(self, image: str) -> str:
|
||||||
|
if image.startswith('data:image'):
|
||||||
|
return image.split(',')[1]
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
def _handle_generate_response(self, model: str,
|
def _handle_generate_response(self, model: str,
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
tools: Optional[list[PromptMessageTool]],
|
tools: Optional[list[PromptMessageTool]],
|
||||||
@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
|||||||
|
|
||||||
return message_text
|
return message_text
|
||||||
|
|
||||||
|
def _convert_messages_to_prompt(self, messages: list[PromptMessage],
|
||||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
|
tools: Optional[list[PromptMessageTool]] = None) -> str:
|
||||||
"""
|
"""
|
||||||
:param messages: List of PromptMessage to combine.
|
:param messages: List of PromptMessage to combine.
|
||||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user