diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b54668a12d..7d19ccbb74 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ try: - ping_message = SystemPromptMessage(content="ping") + ping_message = UserPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_output_tokens": 5}) except Exception as ex: @@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel): config_kwargs["stop_sequences"] = stop genai.configure(api_key=credentials["google_api_key"]) - google_model = genai.GenerativeModel(model_name=model) history = [] + system_instruction = None for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) + elif content["role"] == "system": + system_instruction = content["parts"][0] else: history.append(content) + if not history: + raise InvokeError("The user prompt message is required. You only add a system prompt message.") + + google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, generation_config=genai.types.GenerationConfig(**config_kwargs), @@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ) return glm_content elif isinstance(message, SystemPromptMessage): - return {"role": "user", "parts": [to_part(message.content)]} + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) + return {"role": "system", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function",