diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 27277164c9..ee09b8cb74 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     PromptMessageRole,
     PromptMessageTool,
@@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON
 
 ```JSON"""
 
+
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
@@ -159,7 +161,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         if len(prompt_messages) == 0:
             raise ValueError('At least one message is required')
-        
+
         if prompt_messages[0].role == PromptMessageRole.SYSTEM:
             if not prompt_messages[0].content:
                 prompt_messages = prompt_messages[1:]
@@ -185,7 +187,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     continue
 
                 if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
-                    copy_prompt_message.role == PromptMessageRole.USER:
+                        copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
                 else:
                     if copy_prompt_message.role == PromptMessageRole.USER:
@@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     new_prompt_messages.append(copy_prompt_message)
 
         if model == 'glm-4v':
-            params = {
-                'model': model,
-                'messages': [{
-                    'role': prompt_message.role.value,
-                    'content':
-                        [
-                            {
-                                'type': 'text',
-                                'text': prompt_message.content
-                            }
-                        ] if isinstance(prompt_message.content, str) else
-                        [
-                            {
-                                'type': 'image',
-                                'image_url': {
-                                    'url': content.data
-                                }
-                            } if content.type == PromptMessageContentType.IMAGE else {
-                                'type': 'text',
-                                'text': content.data
-                            } for content in prompt_message.content
-                        ],
-                } for prompt_message in new_prompt_messages],
-                **model_parameters
-            }
+            params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
         else:
             params = {
                 'model': model,
@@ -277,8 +255,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             for prompt_message in new_prompt_messages:
                 # merge system message to user message
                 if prompt_message.role == PromptMessageRole.SYSTEM or \
-                    prompt_message.role == PromptMessageRole.TOOL or \
-                    prompt_message.role == PromptMessageRole.USER:
+                        prompt_message.role == PromptMessageRole.TOOL or \
+                        prompt_message.role == PromptMessageRole.USER:
                     if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
                         params['messages'][-1]['content'] += "\n\n" + prompt_message.content
                     else:
@@ -306,8 +284,44 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             response = client.chat.completions.create(**params, **extra_model_kwargs)
 
             return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
-    
-    def _handle_generate_response(self, model: str, 
+
+    def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage],
+                                    model_parameters: dict):
+        messages = [
+            {
+                'role': message.role.value,
+                'content': self._construct_glm_4v_messages(message.content)
+            }
+            for message in prompt_messages
+        ]
+
+        params = {
+            'model': model,
+            'messages': messages,
+            **model_parameters
+        }
+
+        return params
+
+    def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
+        if isinstance(prompt_message, str):
+            return [{'type': 'text', 'text': prompt_message}]
+
+        return [
+            {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}}
+            if item.type == PromptMessageContentType.IMAGE else
+            {'type': 'text', 'text': item.data}
+            for item in prompt_message
+        ]
+
+    def _remove_image_header(self, image: str) -> str:
+        if image.startswith('data:image'):
+            return image.split(',')[1]
+
+        return image
+
+    def _handle_generate_response(self, model: str,
                                   credentials: dict,
                                   tools: Optional[list[PromptMessageTool]],
                                   response: Completion,
@@ -338,7 +352,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             )
 
             text += choice.message.content or ''
-        
+
         prompt_usage = response.usage.prompt_tokens
         completion_usage = response.usage.completion_tokens
 
@@ -358,7 +372,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, 
+    def _handle_generate_stream_response(self, model: str,
                                          credentials: dict,
                                          tools: Optional[list[PromptMessageTool]],
                                          responses: Generator[ChatCompletionChunk, None, None],
@@ -380,7 +394,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
-            
+
             assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
             for tool_call in delta.delta.tool_calls or []:
                 if tool_call.type == 'function':
@@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return message_text
 
-
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage],
+                                    tools: Optional[list[PromptMessageTool]] = None) -> str:
         """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
@@ -473,4 +487,4 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             text += f"\n{tool.json()}"
 
         # trim off the trailing ' ' that might come from the "Assistant: "
-        return text.rstrip() 
\ No newline at end of file
+        return text.rstrip()
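
Reviewer note: the refactor extracts the inline GLM-4V message construction into `_construct_glm_4v_parameter` / `_construct_glm_4v_messages` / `_remove_image_header`, and in doing so changes image parts from `'type': 'image'` to `'type': 'image_url'` and strips any `data:image/...;base64,` prefix before sending. Below is a minimal, self-contained sketch (not part of the patch) of what the new helpers produce; `TextContent` and `ImageContent` are simplified stand-ins for dify's `PromptMessageContent` types, used here only for illustration.

```python
from dataclasses import dataclass


@dataclass
class TextContent:
    data: str
    type: str = 'text'


@dataclass
class ImageContent:
    data: str
    type: str = 'image'


def remove_image_header(image: str) -> str:
    # Strip a data-URI prefix such as 'data:image/png;base64,' so only
    # the raw base64 payload is passed along (mirrors _remove_image_header).
    if image.startswith('data:image'):
        return image.split(',')[1]
    return image


def construct_glm_4v_messages(content) -> list[dict]:
    # Mirrors _construct_glm_4v_messages: plain strings become a single
    # text part; mixed content lists map images to 'image_url' parts and
    # everything else to text parts.
    if isinstance(content, str):
        return [{'type': 'text', 'text': content}]
    return [
        {'type': 'image_url', 'image_url': {'url': remove_image_header(item.data)}}
        if item.type == 'image' else
        {'type': 'text', 'text': item.data}
        for item in content
    ]


print(construct_glm_4v_messages('hello'))
# [{'type': 'text', 'text': 'hello'}]

print(construct_glm_4v_messages([
    TextContent('describe this image'),
    ImageContent('data:image/png;base64,iVBORw0KGgo='),
]))
# [{'type': 'text', 'text': 'describe this image'},
#  {'type': 'image_url', 'image_url': {'url': 'iVBORw0KGgo='}}]
```

The behavior is unchanged for the non-GLM-4V branch; only the vision-model payload construction moves into the helpers.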