From ede69b46591dd983e9b3e56f1f6effa385f67209 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 3 Jan 2024 17:45:15 +0800
Subject: [PATCH] fix: gemini block error (#1877)

Co-authored-by: chenhe
---
 .../__base/large_language_model.py    |  8 ++++----
 .../model_providers/google/llm/llm.py | 18 +++++++++++++-----
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 6b5ea88d40..54f6f89cd8 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
         system_fingerprint = None
         real_model = model
 
-        for chunk in result:
-            try:
+        try:
+            for chunk in result:
                 yield chunk
 
                 self._trigger_new_chunk_callbacks(
@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):
                 if chunk.system_fingerprint:
                     system_fingerprint = chunk.system_fingerprint
 
-            except Exception as e:
-                raise self._transform_invoke_error(e)
+        except Exception as e:
+            raise self._transform_invoke_error(e)
 
         self._trigger_after_invoke_callbacks(
             model=model,
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 9887e07a69..86e87e8c1e 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -3,6 +3,7 @@
 from typing import Optional, Generator, Union, List
 import google.generativeai as genai
 import google.api_core.exceptions as exceptions
 import google.generativeai.client as client
+from google.generativeai.types import HarmCategory, HarmBlockThreshold
 from google.generativeai.types import GenerateContentResponse, ContentType
 from google.generativeai.types.content_types import to_part
@@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             last_msg = prompt_messages[-1]
             content = self._format_message_to_glm_content(last_msg)
             history.append(content)
-        else:
+        else:
             for msg in prompt_messages:  # makes message roles strictly alternating
                 content = self._format_message_to_glm_content(msg)
                 if history and history[-1]["role"] == content["role"]:
@@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             new_custom_client = new_client_manager.make_client("generative")
 
             google_model._client = new_custom_client
+
+        safety_settings={
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        }
 
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(
                 **config_kwargs
             ),
-            stream=stream
+            stream=stream,
+            safety_settings=safety_settings
         )
 
         if stream:
@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             content=response.text
        )
 
-        # calculate num tokens
         prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
         completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
 
@@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         for chunk in response:
             content = chunk.text
             index += 1
-
+
             assistant_prompt_message = AssistantPromptMessage(
                 content=content if content else '',
             )
-
+
             if not response._done:
                 # transform assistant message to prompt message
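
Why the large_language_model.py change matters: with a generator, an exception
raised while fetching the next chunk surfaces at the `for` statement itself, so
the old per-iteration try/except never saw errors thrown by the underlying
stream (such as Gemini aborting a blocked response). Below is a minimal
standalone sketch of the difference; the function names are illustrative only,
not part of the patch:

    def stream_chunks():
        # Illustrative generator standing in for the provider stream (`result`).
        yield "hello"
        raise RuntimeError("response blocked mid-stream")

    def consume_old():
        # Old shape: only the loop body is guarded. The RuntimeError is raised
        # by the generator while fetching the next chunk, i.e. at the `for`
        # statement itself, so it escapes this try/except untransformed.
        for chunk in stream_chunks():
            try:
                print(chunk)
            except Exception as e:
                print("never reached for errors raised by the stream:", e)

    def consume_new():
        # New shape: the `for` statement sits inside the try block, so errors
        # raised while pulling chunks are caught and can be mapped to a
        # unified invoke error, as _transform_invoke_error does in the patch.
        try:
            for chunk in stream_chunks():
                print(chunk)
        except Exception as e:
            print("caught:", e)

    consume_new()   # prints "hello", then "caught: response blocked mid-stream"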
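
For context on the loop touched by the whitespace-only `else:` hunk: Gemini
requires strictly alternating user/model roles, so consecutive same-role
messages are merged into one history entry. A small self-contained sketch of
that merging logic, using hypothetical message dicts in place of the formatted
glm content:

    history = []
    messages = [
        {"role": "user", "parts": ["hi"]},
        {"role": "user", "parts": ["are you there?"]},  # same role as previous
        {"role": "model", "parts": ["hello!"]},
    ]
    for content in messages:
        if history and history[-1]["role"] == content["role"]:
            # Fold consecutive same-role messages into one entry so the final
            # history alternates user/model as Gemini requires.
            history[-1]["parts"].extend(content["parts"])
        else:
            history.append(content)

    print(history)
    # [{'role': 'user', 'parts': ['hi', 'are you there?']},
    #  {'role': 'model', 'parts': ['hello!']}]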
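
The safety_settings hunk passes explicit thresholds to generate_content so that
Gemini no longer blocks responses at its default harm-category settings (the
"block error" in the subject). A hedged standalone sketch of the same
configuration against the google-generativeai SDK the patch already imports;
the API key and model name are placeholders:

    import google.generativeai as genai
    from google.generativeai.types import HarmBlockThreshold, HarmCategory

    genai.configure(api_key="YOUR_API_KEY")  # placeholder credential

    # Mirror the patch: turn off blocking for all four harm categories.
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    model = genai.GenerativeModel("gemini-pro")  # model name is a placeholder
    response = model.generate_content(
        "Hello",
        safety_settings=safety_settings,
        stream=False,
    )
    print(response.text)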